# copied from: https://github.com/NVIDIA/TensorRT-LLM/blob/v0.18.1/examples/enc_dec/convert_checkpoint.py
import argparse
import configparser
import copy
import json
import logging
import os
import types
from ast import literal_eval
from datetime import datetime
from pathlib import Path

import safetensors
from helper import convert_weight_to_dtype, fuse_qkv_one_layer, reshape, split
from transformers import (AutoModelForSeq2SeqLM, Blip2ForConditionalGeneration,
                          MBartForConditionalGeneration,
                          Pix2StructForConditionalGeneration,
                          T5ForConditionalGeneration, VisionEncoderDecoderModel)

from tensorrt_llm.functional import (LayerNormPositionType, LayerNormType,
                                     MLPType)
from tensorrt_llm.models import PretrainedConfig

dir_path = os.path.dirname(os.path.realpath(__file__))
LOGGER = logging.getLogger(__name__)

layernorm_type_map = {i.name: i.value for i in LayerNormType}
layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
mlp_type_map = {i.name: i.value for i in MLPType}


def copy_args_to_component_config(component_config, args):
    for arg in vars(args):
        setattr(component_config, arg, getattr(args, arg))
    return component_config
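

# Example invocation for a T5-family checkpoint (a sketch; the model id,
# output path, and parallelism sizes below are illustrative assumptions, not
# fixed values):
#
#   python convert_checkpoint.py --model_type t5 \
#       --model_dir t5-small \
#       --output_dir ./tllm_ckpt/t5_small \
#       --tp_size 1 --pp_size 1 --dtype float16
#
# The script writes a TensorRT-LLM checkpoint (config.json plus one
# rank<N>.safetensors per rank) under <output_dir>/encoder and
# <output_dir>/decoder; see convert_checkpoint() below.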
def parse_t5_config(args, hf_model):
    config = configparser.ConfigParser()

    config["encoder"] = {}
    for key, val in hf_model.encoder.config.to_dict().items():
        config["encoder"][key] = f"{val}"

    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
    def get_offset_q_scaling(config):
        scaling = 1 / config.head_size**.5
        return scaling

    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "false"
    config["structure"]["use_gated_activation"] = str(
        hf_model.encoder.config.is_gated_act)
    config["structure"]["position_embedding_type"] = "relative"
    config["structure"]["model_type"] = args.model_type

    def parse_t5_config_by_component(config, component, args):
        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(component_config, args)
        component_config.n_head = config.getint(component, 'num_heads')
        component_config.head_size = config.getint(component, 'd_kv')
        component_config.hidden_size = config.getint(component, 'd_model')
        component_config.ffn_hidden_size = config.getint(component, 'd_ff')
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(component,
                                                     'n_positions',
                                                     fallback=512)
        component_config.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=False)  # TODO: hardcoded here
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=False)
        component_config.has_embedding_scale = config.getboolean(
            component, 'has_embedding_scale', fallback=False)
        component_config.q_scaling = get_offset_q_scaling(component_config)
        component_config.has_attention_qkvo_bias = config.getboolean(
            component, 'has_attention_qkvo_bias',
            fallback=False)  # TODO: hardcoded here
        component_config.has_mlp_bias = config.getboolean(component,
                                                          'has_mlp_bias',
                                                          fallback=False)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm', fallback=True)
        component_config.layernorm_eps = config.getfloat(
            component, 'layer_norm_epsilon')
        component_config.layernorm_position = layernorm_position_map[config.get(
            component, 'layernorm_position',
            fallback='pre_layernorm')]  # TODO: hardcoded here
        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='RmsNorm')]
        component_config.hidden_act = config.get(component, 'dense_act_fn')
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act')
        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
                                                 gated_act else 'MLP']
        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets')
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance')
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')

        if component == 'encoder':
            component_config.n_layer = config.getint(component, 'num_layers')
            component_config.relative_attention = config.get(
                'structure', 'position_embedding_type') == 'relative'
        elif component == 'decoder':
            component_config.n_layer = config.getint(component,
                                                     'num_decoder_layers')
            component_config.has_lm_head_bias = config.getboolean(
                component,  # TODO: T5 with bias
                'has_lm_head_bias',
                fallback=False)
            component_config.relative_attention = config.getboolean(
                component, 'relative_attention', fallback=True)
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'tie_word_embeddings'
            )  # default is True (for T5), but False for Flan-T5
            component_config.encoder_hidden_size = config.getint(
                'encoder', 'd_model')
            component_config.encoder_num_heads = config.getint(
                'encoder', 'num_heads')
            component_config.encoder_head_size = config.getint(
                'encoder', 'd_kv')
            component_config.decoder_start_token_id = config.getint(
                'decoder', 'decoder_start_token_id')
            component_config.eos_token_id = config.getint(
                'decoder', 'eos_token_id')
            bos_token_id = config.get('decoder', 'bos_token_id')
            # T5 does not have bos_token_id
            component_config.bos_token_id = int(
                bos_token_id) if bos_token_id != "None" else None
            component_config.pad_token_id = config.getint(
                'decoder', 'pad_token_id')
        else:
            assert False, 'Unsupported component!'

        return component_config

    encoder_config = parse_t5_config_by_component(config, "encoder", args)
    decoder_config = parse_t5_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config
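

# Worked example for the q_scaling offset above: T5 applies no 1/sqrt(d_k)
# scaling inside attention, whereas (as the TODO above implies) the
# TensorRT-LLM attention kernel divides scores by q_scaling * sqrt(head_size).
# Setting q_scaling = 1 / sqrt(head_size) cancels that factor: for d_kv = 64,
# q_scaling = 1 / 8 = 0.125 and the effective scale is 1 / (0.125 * 8) = 1,
# i.e. unscaled attention, matching T5.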
def convert_t5_weights_to_tllm_safetensors(config, component, params):
    weights = {}

    mapping = config.mapping

    convert_weight_to_dtype(params, config.dtype)
    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    n_head = config.num_attention_heads
    head_size = config.head_size
    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5

    hf_param_prefix = f'{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
    hf_component_idx = 1 if component == 'encoder' else 2

    def get_attn_module_name(component, block, layer, attn_type):
        return f'{component}.block.{int(block)}.layer.{int(layer)}.{attn_type}'

    weights['embedding.vocab_embedding.weight'] = reshape(
        params['shared.weight'].clone(), None)

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
        hf_layer_name_prefix = f'{hf_param_prefix}.block.{layer_idx}'

        hidden_layer_name_split = {
            f'{hf_layer_name_prefix}.layer.0.SelfAttention.o.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wo.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
                "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_0.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
        }

        hidden_layer_name_no_split = {
            f'{hf_layer_name_prefix}.layer.0.layer_norm.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
                "shape": None
            },
            f'{hf_layer_name_prefix}.layer.{hf_component_idx}.layer_norm.weight':
            {
                "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
                "shape": None
            },
        }

        if config.gated_act:
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi2.weight':
                {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
                f'{hf_layer_name_prefix}.layer.{hf_component_idx}.DenseReluDense.wi_1.weight':
                {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
            })

        if component == 'decoder':
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.layer.1.EncDecAttention.o.weight': {
                    "name":
                    f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
                    "shape":
                    (hidden_size, attention_hidden_size // mapping.tp_size),
                    "split_dim": -1
                },
            })
            hidden_layer_name_no_split.update({
                f'{hf_layer_name_prefix}.layer.1.layer_norm.weight': {
                    "name":
                    f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
                    "shape": None
                },
            })

            self_attn_module_name = get_attn_module_name(
                component, layer_idx, "1", 'EncDecAttention')
            weights.update(
                fuse_qkv_one_layer(
                    params, self_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                    None))

        self_attn_module_name = get_attn_module_name(component, layer_idx, "0",
                                                     'SelfAttention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        weights[
            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    split(params[hf_weight_name],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])
        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    params[hf_weight_name].clone(), shape=weight_info["shape"])

    weights['final_layernorm.weight'] = reshape(
        params[f'{component}.final_layer_norm.weight'].clone(), None)

    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params['lm_head.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
        if not config.use_implicit_relative_attention:
            weights['rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.block.0.layer.0.SelfAttention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))

    return weights


convert_blip2_weights_to_tllm_safetensors = convert_t5_weights_to_tllm_safetensors  # func alias
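

# Minimal sketch of the tensor-parallel sharding convention used by the
# converters above, written with plain torch. `_demo_tp_split` is a
# hypothetical helper added for illustration only and is never called: MLP
# fc-style weights are split along dim 0 (output rows), dense/proj-style
# weights along the last dim (input columns), mirroring the "split_dim"
# entries in the hidden_layer_name_split tables.
def _demo_tp_split():
    import torch

    tp_size, tp_rank = 2, 0
    fc_weight = torch.randn(8, 4)    # (ffn_hidden_size, hidden_size)
    proj_weight = torch.randn(4, 8)  # (hidden_size, ffn_hidden_size)

    # Take this rank's shard, analogous to how the imported `split` helper
    # is used above.
    fc_shard = torch.chunk(fc_weight, tp_size, dim=0)[tp_rank]       # (4, 4)
    proj_shard = torch.chunk(proj_weight, tp_size, dim=-1)[tp_rank]  # (4, 4)
    return fc_shard.shape, proj_shard.shape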
def parse_nmt_config(args, model):
    config = configparser.ConfigParser()
    fairseq_config = vars(model.cfg.model)  # Namespace --> dict

    config['encoder'] = dict()
    for key, val in fairseq_config.items():
        config["encoder"][key] = f"{val}"
    config["encoder"]["q_scaling"] = '1'
    # NMT has final layernorm for pre-norm model architecture.
    config['encoder']['has_model_final_layernorm'] = config['encoder'][
        'encoder_normalize_before']
    config['encoder']['vocab_size'] = str(len(model.src_dict))  # fairseq naming

    config['decoder'] = dict()
    for key, val in fairseq_config.items():
        config["decoder"][key] = f"{val}"
    config["decoder"]["q_scaling"] = '1'
    config["decoder"]["rescale_before_lm_head"] = 'false'
    config['decoder']['has_model_final_layernorm'] = str(
        config['decoder'].getboolean('decoder_normalize_before', False)
        and not config['decoder'].getboolean('no_decoder_final_norm', False))
    config['decoder']['vocab_size'] = str(len(model.tgt_dict))  # fairseq naming

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "true"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"][
        "position_embedding_type"] = "learned_absolute"  # "sinusoid"
    config["structure"]["model_type"] = args.model_type

    def parse_nmt_config_by_component(config, component, args):
        assert component in ('encoder', 'decoder'), 'Unsupported component!'
        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(component_config, args)
        component_config.n_layer = config.getint(component,
                                                 f'{component}_layers')
        component_config.n_head = config.getint(component,
                                                f'{component}_attention_heads')
        component_config.hidden_size = config.getint(
            component, f'{component}_embed_dim')  # fairseq naming
        component_config.head_size = config.getint(
            component,
            'd_kv',
            fallback=component_config.hidden_size // component_config.n_head)
        component_config.ffn_hidden_size = config.getint(
            component, f'{component}_ffn_embed_dim')  # fairseq naming
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(
            component, 'max_source_positions')  # fairseq naming
        component_config.has_position_embedding = not config.getboolean(
            component, 'no_token_positional_embeddings',
            fallback=False)  # fairseq naming
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'layernorm_embedding', fallback=True)  # fairseq naming
        component_config.has_embedding_scale = not config.getboolean(
            component, 'no_scale_embedding')  # fairseq naming
        component_config.q_scaling = config.getfloat(component,
                                                     'q_scaling',
                                                     fallback=1.0)
        component_config.has_attention_qkvo_bias = config.getboolean(
            'structure', 't5_with_bias', fallback=True)
        component_config.has_mlp_bias = config.getboolean('structure',
                                                          't5_with_bias',
                                                          fallback=True)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm')
        component_config.layernorm_eps = config.getfloat(
            component, 'layer_norm_epsilon', fallback=1e-5)  # fairseq naming

        normalize_before = config.getboolean(
            component, f'{component}_normalize_before')  # fairseq naming
        component_config.layernorm_position = layernorm_position_map[
            'pre_layernorm' if normalize_before else 'post_layernorm']

        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='LayerNorm')]
        component_config.hidden_act = config.get(
            component, 'activation_fn')  # fairseq naming
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act',
                                                       fallback=False)
        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
                                                 gated_act else 'MLP']
        component_config.relative_attention = config.get(
            'structure', 'position_embedding_type') == 'relative'
        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets', fallback=0)
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance', fallback=0)
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')

        if component == 'decoder':
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'rescale_before_lm_head')
            component_config.encoder_hidden_size = config.getint(
                'encoder', 'encoder_embed_dim')  # fairseq naming
            component_config.encoder_num_heads = config.getint(
                'encoder', 'encoder_attention_heads')
            component_config.encoder_head_size = config.getint(
                'encoder',
                'd_kv',
                fallback=component_config.encoder_hidden_size //
                component_config.encoder_num_heads)
            component_config.decoder_start_token_id = None
            component_config.eos_token_id = None
            component_config.bos_token_id = None
            component_config.pad_token_id = None

        return component_config

    encoder_config = parse_nmt_config_by_component(config, "encoder", args)
    decoder_config = parse_nmt_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config
def convert_nmt_weights_to_tllm_safetensors(config, component, params,
                                            sin_pos_embedding):
    weights = {}
    mapping = config.mapping

    hidden_size = config.hidden_size

    convert_weight_to_dtype(params, config.dtype)
    ffn_hidden_size = config.intermediate_size
    vocab_size = config.vocab_size

    hf_param_prefix = f'models.0.{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'

    hidden_layer_name_split = {
        'self_attn.out_proj.weight': {
            "name": f'{trtllm_attn_layer_name}.dense.weight',
            "shape": (hidden_size, hidden_size // mapping.tp_size),
            "split_dim": -1
        },
        'fc1.weight': {
            "name": 'mlp.fc.weight',
            "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
            "split_dim": 0
        },
        'fc1.bias': {
            "name": 'mlp.fc.bias',
            "shape": (ffn_hidden_size // mapping.tp_size),
            "split_dim": 0
        },
        'fc2.weight': {
            "name": 'mlp.proj.weight',
            "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
            "split_dim": -1
        },
    }

    hidden_layer_name_no_split = {
        'self_attn.out_proj.bias': {
            "name": f'{trtllm_attn_layer_name}.dense.bias',
            "shape": (hidden_size)
        },
        'self_attn_layer_norm.weight': {
            "name": f'{trtllm_attn_layernorm_name}.weight',
            "shape": None
        },
        'self_attn_layer_norm.bias': {
            "name": f'{trtllm_attn_layernorm_name}.bias',
            "shape": None
        },
        'fc2.bias': {
            "name": 'mlp.proj.bias',
            "shape": (hidden_size)
        },
        'final_layer_norm.weight': {
            "name": 'mlp_layernorm.weight',
            "shape": None
        },
        'final_layer_norm.bias': {
            "name": 'mlp_layernorm.bias',
            "shape": None
        },
    }

    if component == "decoder":
        hidden_layer_name_split.update({
            'encoder_attn.out_proj.weight': {
                "name": 'cross_attention.dense.weight',
                "shape": (hidden_size, hidden_size // mapping.tp_size),
                "split_dim": -1
            },
        })
        hidden_layer_name_no_split.update({
            'encoder_attn.out_proj.bias': {
                "name": 'cross_attention.dense.bias',
                "shape": (hidden_size)
            },
            'encoder_attn_layer_norm.weight': {
                "name": 'cross_attention_layernorm.weight',
                "shape": None,
            },
            'encoder_attn_layer_norm.bias': {
                "name": 'cross_attention_layernorm.bias',
                "shape": None
            },
        })

    def get_attn_module_name(component, layer, attn_type):
        return f'models.0.{component}.layers.{int(layer)}.{attn_type}'

    weights["embedding.vocab_embedding.weight"] = reshape(
        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(),
        (vocab_size, -1))
    weights["embedding.position_embedding.weight"] = reshape(
        sin_pos_embedding, (config.max_position_embeddings, hidden_size))

    num_layers = config.num_hidden_layers

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            weights[
                f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
                    split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])

        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
            hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
            weights[trtllm_layer_fullname] = reshape(
                params[hf_layer_fullname].clone(), shape=weight_info["shape"])

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attn')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (hidden_size * 3 // mapping.tp_size, hidden_size),
                (hidden_size * 3 // mapping.tp_size)))
        if component == 'decoder':
            cross_attn_module_name = get_attn_module_name(
                component, layer_idx, 'encoder_attn')
            weights.update(
                fuse_qkv_one_layer(
                    params, cross_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (hidden_size * 3 // mapping.tp_size, hidden_size),
                    (hidden_size * 3 // mapping.tp_size)))

    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params[f'{hf_param_prefix}.output_projection.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))

    if config.has_model_final_layernorm:
        weights['final_layernorm.weight'] = params[
            f'{hf_param_prefix}.layer_norm.weight'].clone()
        weights['final_layernorm.bias'] = params[
            f'{hf_param_prefix}.layer_norm.bias'].clone()

    return weights
def parse_bart_config(args, hf_model):

    config = configparser.ConfigParser()

    config['decoder'] = dict()
    for key, val in hf_model.model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"
    config["decoder"]["q_scaling"] = '1'
    config["decoder"]["rescale_before_lm_head"] = str(False)
    config['decoder']['has_model_final_layernorm'] = str(
        args.nougat or isinstance(hf_model, MBartForConditionalGeneration))

    if args.nougat:
        # These flags are true for mbart decoders, but missing in HF config
        config['decoder']['normalize_before'] = str(True)
        config['decoder']['normalize_embeddings'] = str(True)

        config['encoder'] = dict()
        # Init few encoder configs, needed by build, from decoder config
        encoder_config_keys = [
            "encoder_ffn_dim", "encoder_layers", "encoder_attention_heads",
            "encoder_layerdrop", "d_model"
        ]
        for key in encoder_config_keys:
            config['encoder'][key] = config['decoder'][key]
    else:
        config['encoder'] = dict()
        for key, val in hf_model.model.encoder.config.to_dict().items():
            config["encoder"][key] = f"{val}"
        config["encoder"]["q_scaling"] = '1'

        # mBART has final layernorm, BART does not
        config['encoder']['has_model_final_layernorm'] = str(
            isinstance(hf_model, MBartForConditionalGeneration))

    config["structure"] = dict()
    config["structure"]["t5_with_bias"] = "true"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"]["position_embedding_type"] = "learned_absolute"
    config["structure"]["model_type"] = args.model_type

    def parse_bart_config_by_component(config, component, args):
        assert component in ('encoder', 'decoder'), 'Unsupported component!'
        component_config = types.SimpleNamespace()
        component_config = copy_args_to_component_config(component_config, args)
        component_config.n_layer = config.getint(component,
                                                 f'{component}_layers')
        component_config.n_head = config.getint(component,
                                                f'{component}_attention_heads')
        component_config.hidden_size = config.getint(component, 'd_model')
        component_config.head_size = config.getint(
            component,
            'd_kv',
            fallback=component_config.hidden_size // component_config.n_head)
        component_config.ffn_hidden_size = config.getint(
            component, f'{component}_ffn_dim')
        component_config.vocab_size = config.getint(component, 'vocab_size')
        component_config.n_positions = config.getint(component,
                                                     'max_position_embeddings')
        component_config.has_position_embedding = config.getboolean(
            component, 'has_position_embedding',
            fallback=True)  # TODO: hardcoded here
        component_config.has_token_type_embedding = config.getboolean(
            component, 'has_token_type_embedding', fallback=False)
        component_config.has_embedding_layernorm = config.getboolean(
            component, 'has_embedding_layernorm', fallback=True)
        component_config.has_embedding_scale = config.getboolean(
            component, 'scale_embedding')
        component_config.q_scaling = config.getfloat(component,
                                                     'q_scaling',
                                                     fallback=1.0)
        component_config.has_attention_qkvo_bias = config.getboolean(
            'structure', 't5_with_bias', fallback=True)
        component_config.has_mlp_bias = config.getboolean('structure',
                                                          't5_with_bias',
                                                          fallback=True)
        component_config.has_model_final_layernorm = config.getboolean(
            component, 'has_model_final_layernorm')
        component_config.layernorm_eps = config.getfloat(component,
                                                         'layer_norm_epsilon',
                                                         fallback=False)

        normalize_before = config.getboolean(component, 'normalize_before')
        component_config.layernorm_position = layernorm_position_map[
            'pre_layernorm' if normalize_before else 'post_layernorm']

        component_config.layernorm_type = layernorm_type_map[config.get(
            component, 'layernorm_type', fallback='LayerNorm')]
        component_config.hidden_act = config.get(component,
                                                 'activation_function')
        component_config.gated_act = config.getboolean(component,
                                                       'is_gated_act',
                                                       fallback=False)
        component_config.mlp_type = mlp_type_map['GatedMLP' if component_config.
                                                 gated_act else 'MLP']
        component_config.relative_attention = config.get(
            'structure', 'position_embedding_type') == 'relative'
        component_config.num_buckets = config.getint(
            component, 'relative_attention_num_buckets', fallback=0)
        component_config.max_distance = config.getint(
            component, 'relative_attention_max_distance', fallback=0)
        component_config.max_lora_rank = config.getint(component,
                                                       'max_lora_rank',
                                                       fallback=0)
        component_config.lora_target_modules = literal_eval(
            config.get(component, 'lora_target_modules', fallback="[]"))
        component_config.hf_modules_to_trtllm_modules = literal_eval(
            config.get(component, 'hf_modules_to_trtllm_modules',
                       fallback="{}"))
        component_config.trtllm_modules_to_hf_modules = literal_eval(
            config.get(component, 'trtllm_modules_to_hf_modules',
                       fallback="{}"))
        component_config.logits_dtype = config.get(component,
                                                   'logits_dtype',
                                                   fallback='float32')
        component_config.position_embedding_type = config.get(
            'structure', 'position_embedding_type')

        if component == 'decoder':
            component_config.rescale_before_lm_head = config.getboolean(
                component, 'rescale_before_lm_head')
            component_config.encoder_hidden_size = config.getint(
                'encoder', 'd_model')
            component_config.encoder_num_heads = config.getint(
                'encoder', 'encoder_attention_heads')
            component_config.encoder_head_size = config.getint(
                'encoder',
                'd_kv',
                fallback=component_config.encoder_hidden_size //
                component_config.encoder_num_heads)

            # nougat has decoder_start_token_id = None, special handling
            decoder_start_token_id = config.get('decoder',
                                                'decoder_start_token_id')
            component_config.decoder_start_token_id = int(
                decoder_start_token_id
            ) if decoder_start_token_id != "None" else None
            component_config.eos_token_id = config.getint(
                'decoder', 'eos_token_id')
            component_config.bos_token_id = config.getint(
                'decoder', 'bos_token_id')
            component_config.pad_token_id = config.getint(
                'decoder', 'pad_token_id')

        return component_config

    encoder_config = None
    if not args.nougat:
        encoder_config = parse_bart_config_by_component(config, "encoder", args)
    decoder_config = parse_bart_config_by_component(config, "decoder", args)

    return encoder_config, decoder_config
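

# For BART-family models the same entry point is used; the --nougat flag
# selects the vision-encoder + mBART-decoder path, in which case only the
# decoder is parsed and converted. Illustrative invocations (the model ids and
# output paths are assumptions, not requirements):
#
#   python convert_checkpoint.py --model_type bart \
#       --model_dir facebook/bart-large-cnn --output_dir ./tllm_ckpt/bart
#
#   python convert_checkpoint.py --model_type bart --nougat \
#       --model_dir facebook/nougat-base --output_dir ./tllm_ckpt/nougat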
def convert_bart_weights_to_tllm_safetensors(config, component, params):
    weights = {}
    mapping = config.mapping

    hidden_size = config.hidden_size

    convert_weight_to_dtype(params, config.dtype)
    ffn_hidden_size = config.intermediate_size
    vocab_size = config.vocab_size

    hf_param_prefix = f'model.{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'attention' if component == 'encoder' else 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm' if component == 'decoder' else 'attention_layernorm'
    embedding_layer_names = {
        'embed_tokens.weight': {
            "name": 'embedding.vocab_embedding.weight',
            "shape": (vocab_size, -1)
        },
        'embed_positions.weight': {
            "name": 'embedding.position_embedding.weight',
            "shape": (config.max_position_embeddings, hidden_size)
        },
        'layernorm_embedding.weight': {
            "name": 'embedding.embedding_layernorm.weight',
            "shape": None
        },
        'layernorm_embedding.bias': {
            "name": 'embedding.embedding_layernorm.bias',
            "shape": None
        },
    }

    hidden_layer_name_split = {
        'self_attn.out_proj.weight': {
            "name": f'{trtllm_attn_layer_name}.dense.weight',
            "shape": (hidden_size, hidden_size // mapping.tp_size),
            "split_dim": -1
        },
        'fc1.weight': {
            "name": 'mlp.fc.weight',
            "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
            "split_dim": 0
        },
        'fc1.bias': {
            "name": 'mlp.fc.bias',
            "shape": (ffn_hidden_size // mapping.tp_size),
            "split_dim": 0
        },
        'fc2.weight': {
            "name": 'mlp.proj.weight',
            "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
            "split_dim": -1
        },
    }

    hidden_layer_name_no_split = {
        'self_attn.out_proj.bias': {
            "name": f'{trtllm_attn_layer_name}.dense.bias',
            "shape": (hidden_size)
        },
        'self_attn_layer_norm.weight': {
            "name": f'{trtllm_attn_layernorm_name}.weight',
            "shape": None
        },
        'self_attn_layer_norm.bias': {
            "name": f'{trtllm_attn_layernorm_name}.bias',
            "shape": None
        },
        'fc2.bias': {
            "name": 'mlp.proj.bias',
            "shape": (hidden_size)
        },
        'final_layer_norm.weight': {
            "name": 'mlp_layernorm.weight',
            "shape": None
        },
        'final_layer_norm.bias': {
            "name": 'mlp_layernorm.bias',
            "shape": None
        },
    }

    if config.model_type == 'mbart':
        hidden_layer_name_split['layer_norm.weight'] = {
            "name": 'final_layernorm.weight',
            "shape": None,
            "split_dim": 0
        }
        hidden_layer_name_no_split['layer_norm.bias'] = {
            "name": 'final_layernorm.bias',
            "shape": None,
            "split_dim": 0
        }

    if component == "decoder":
        hidden_layer_name_split.update({
            'encoder_attn.out_proj.weight': {
                "name": 'cross_attention.dense.weight',
                "shape": (hidden_size, hidden_size // mapping.tp_size),
                "split_dim": -1
            }
        })
        hidden_layer_name_no_split.update({
            'encoder_attn.out_proj.bias': {
                "name": 'cross_attention.dense.bias',
                "shape": (hidden_size)
            },
            'encoder_attn_layer_norm.weight': {
                "name": 'cross_attention_layernorm.weight',
                "shape": None
            },
            'encoder_attn_layer_norm.bias': {
                "name": 'cross_attention_layernorm.bias',
                "shape": None
            },
        })

    def get_attn_module_name(component, layer, attn_type):
        return f'model.{component}.layers.{int(layer)}.{attn_type}'

    for hf_weight_name, weight_info in embedding_layer_names.items():
        if 'position' in hf_weight_name:
            weights[weight_info["name"]] = params[
                f'{hf_param_prefix}.{hf_weight_name}'][2:].clone()
        else:
            weights[weight_info["name"]] = params[
                f'{hf_param_prefix}.{hf_weight_name}'].clone()
        weights[weight_info["name"]] = reshape(weights[weight_info["name"]],
                                               weight_info["shape"])

    num_layers = config.num_hidden_layers

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        hf_layer_name_prefix = f'{hf_param_prefix}.layers.{layer_idx}'
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            weights[
                f'{trtllm_layer_name_prefix}.{weight_info["name"]}'] = reshape(
                    split(params[f'{hf_layer_name_prefix}.{hf_weight_name}'],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])

        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            trtllm_layer_fullname = f'{trtllm_layer_name_prefix}.{weight_info["name"]}'
            hf_layer_fullname = f'{hf_layer_name_prefix}.{hf_weight_name}'
            weights[trtllm_layer_fullname] = reshape(
                params[hf_layer_fullname].clone(), shape=weight_info["shape"])

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attn')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (hidden_size * 3 // mapping.tp_size, hidden_size),
                (hidden_size * 3 // mapping.tp_size)))
        if component == 'decoder':
            cross_attn_module_name = get_attn_module_name(
                component, layer_idx, 'encoder_attn')
            weights.update(
                fuse_qkv_one_layer(
                    params, cross_attn_module_name,
                    f'{trtllm_layer_name_prefix}.cross_attention',
                    mapping.tp_size, mapping.tp_rank, config.model_type,
                    (hidden_size * 3 // mapping.tp_size, hidden_size),
                    (hidden_size * 3 // mapping.tp_size)))

    if component == 'decoder':
        weights['lm_head.weight'] = reshape(
            split(params['lm_head.weight'],
                  mapping.tp_size,
                  mapping.tp_rank,
                  dim=0), (config.vocab_size // mapping.tp_size, hidden_size))

    if config.has_model_final_layernorm:
        weights['final_layernorm.weight'] = params[
            f'{hf_param_prefix}.layer_norm.weight'].clone()
        weights['final_layernorm.bias'] = params[
            f'{hf_param_prefix}.layer_norm.bias'].clone()

    return weights
def parse_pix2struct_config(args, hf_model):
    # manually set q_scaling to offset attention scaling's effect.
    # TODO: modify kernels to control whether to disable attention scaling
    config = configparser.ConfigParser()

    def get_offset_q_scaling(config) -> str:
        d_model = config.hidden_size
        num_heads = config.num_heads
        head_size = d_model / num_heads
        scaling = 1 / head_size**.5
        return str(scaling)

    config["decoder"] = {}
    for key, val in hf_model.decoder.config.to_dict().items():
        config["decoder"][key] = f"{val}"

    config["decoder"]["q_scaling"] = get_offset_q_scaling(
        hf_model.decoder.config)

    config["structure"] = dict()
    config["structure"]["pix2struct_with_bias"] = "false"
    config["structure"]["use_gated_activation"] = "false"
    config["structure"]["position_embedding_type"] = "relative"
    config["structure"]["model_type"] = args.model_type

    def parse_pix2struct_config_by_component(config, component, args):
        if component == 'decoder':
            args.n_layer = config.getint(component, 'num_layers')
            args.n_head = config.getint(component, 'num_heads')
            args.head_size = config.getint(component, 'd_kv')
            args.hidden_size = config.getint(component, 'hidden_size')
            args.ffn_hidden_size = config.getint(component, 'd_ff')
            args.vocab_size = config.getint(component, 'vocab_size')
            args.n_positions = config.getint(component,
                                             'n_positions',
                                             fallback=512)
            args.has_position_embedding = config.getboolean(
                component, 'has_position_embedding',
                fallback=False)  # TODO: hardcoded here
            args.has_token_type_embedding = config.getboolean(
                component, 'has_token_type_embedding', fallback=False)
            args.has_embedding_layernorm = config.getboolean(
                component, 'has_embedding_layernorm', fallback=False)
            args.has_embedding_scale = config.getboolean(component,
                                                         'has_embedding_scale',
                                                         fallback=False)
            args.q_scaling = config.getfloat(component,
                                             'q_scaling',
                                             fallback=1.0)
            args.has_attention_qkvo_bias = config.getboolean(
                component, 'has_attention_qkvo_bias', fallback=False)
            args.has_mlp_bias = config.getboolean(component,
                                                  'has_mlp_bias',
                                                  fallback=False)
            args.has_model_final_layernorm = config.getboolean(
                component, 'has_model_final_layernorm', fallback=True)
            args.layernorm_eps = config.getfloat(component,
                                                 'layer_norm_epsilon')
            args.layernorm_position = layernorm_position_map[config.get(
                component, 'layernorm_position',
                fallback='pre_layernorm')]  # TODO: hardcoded here
            args.layernorm_type = layernorm_type_map[config.get(
                component, 'layernorm_type', fallback='RmsNorm')]
            args.hidden_act = config.get(component, 'dense_act_fn')
            args.gated_act = True
            args.mlp_type = mlp_type_map['GatedMLP' if args.
                                         gated_act else 'MLP']
            args.has_lm_head_bias = config.getboolean(
                component,  # TODO: T5 with bias
                'has_lm_head_bias',
                fallback=False)
            args.relative_attention = config.getboolean(component,
                                                        'relative_attention',
                                                        fallback=True)
            args.num_buckets = config.getint(component,
                                             'relative_attention_num_buckets')
            args.max_distance = config.getint(
                component, 'relative_attention_max_distance')
            args.logits_dtype = config.get(component,
                                           'logits_dtype',
                                           fallback='float32')
            args.rescale_before_lm_head = config.getboolean(
                component, 'tie_word_embeddings'
            )  # default is True (for T5), but False for Flan-T5
            args.encoder_hidden_size = config.getint('decoder', 'hidden_size')
            args.encoder_num_heads = config.getint('decoder', 'num_heads')
            args.encoder_head_size = config.getint('decoder', 'd_kv')
            args.position_embedding_type = config.get(
                'structure', 'position_embedding_type')
            args.decoder_start_token_id = config.getint(
                'decoder', 'decoder_start_token_id')
            args.eos_token_id = config.getint('decoder', 'eos_token_id')
            bos_token_id = config.get('decoder', 'bos_token_id')
            # pix2struct does not have bos_token_id
            args.bos_token_id = int(
                bos_token_id) if bos_token_id != "None" else None
            args.pad_token_id = config.getint('decoder', 'pad_token_id')
        else:
            assert False, 'Unsupported component!'

        return args

    decoder_args = parse_pix2struct_config_by_component(config, "decoder", args)

    return None, decoder_args
def convert_pix2struct_weights_to_tllm_safetensors(config, component, params):
    weights = {}

    mapping = config.mapping

    convert_weight_to_dtype(params, config.dtype)
    hidden_size = config.hidden_size
    ffn_hidden_size = config.intermediate_size
    num_layers = config.num_hidden_layers
    n_head = config.num_attention_heads
    head_size = config.head_size
    attention_hidden_size = n_head * head_size  # head size * num_heads not necessarily equals hidden_dim, such as Flan-T5

    hf_param_prefix = f'{component}'
    trtllm_layer_name = f'{component}_layers'
    trtllm_attn_layer_name = 'self_attention'
    trtllm_attn_layernorm_name = 'self_attention_layernorm'

    def get_attn_module_name(component, layer, attn_type):
        return f'{component}.layer.{int(layer)}.{attn_type}.attention'

    weights['embedding.vocab_embedding.weight'] = reshape(
        params[f'{hf_param_prefix}.embed_tokens.weight'].clone(), None)

    layers_range = mapping.pp_layers(num_layers)
    for layer_idx in layers_range:
        local_layer_idx = layer_idx - layers_range[0]
        trtllm_layer_name_prefix = f'{trtllm_layer_name}.{local_layer_idx}'
        hf_layer_name_prefix = f'{hf_param_prefix}.layer.{layer_idx}'

        hidden_layer_name_split = {
            f'{hf_layer_name_prefix}.self_attention.attention.output.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wo.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.proj.weight',
                "shape": (hidden_size, ffn_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
            f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_0.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp.fc.weight',
                "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                "split_dim": 0
            },
        }

        hidden_layer_name_no_split = {
            f'{hf_layer_name_prefix}.self_attention.layer_norm.weight': {
                "name":
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layernorm_name}.weight',
                "shape": None
            },
            f'{hf_layer_name_prefix}.mlp.layer_norm.weight': {
                "name": f'{trtllm_layer_name_prefix}.mlp_layernorm.weight',
                "shape": None
            },
        }

        if config.gated_act:
            hidden_layer_name_split.update({
                f'{hf_layer_name_prefix}.mlp.DenseReluDense.wi_1.weight': {
                    "name": f'{trtllm_layer_name_prefix}.mlp.gate.weight',
                    "shape": (ffn_hidden_size // mapping.tp_size, hidden_size),
                    "split_dim": 0
                },
            })

        hidden_layer_name_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.attention.output.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention.dense.weight',
                "shape":
                (hidden_size, attention_hidden_size // mapping.tp_size),
                "split_dim": -1
            },
        })
        hidden_layer_name_no_split.update({
            f'{hf_layer_name_prefix}.encoder_decoder_attention.layer_norm.weight':
            {
                "name":
                f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight',
                "shape": None
            },
        })

        self_attn_module_name = get_attn_module_name(
            component, layer_idx, 'encoder_decoder_attention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.cross_attention', mapping.tp_size,
                mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        self_attn_module_name = get_attn_module_name(component, layer_idx,
                                                     'self_attention')
        weights.update(
            fuse_qkv_one_layer(
                params, self_attn_module_name,
                f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}',
                mapping.tp_size, mapping.tp_rank, config.model_type,
                (attention_hidden_size * 3 // mapping.tp_size, hidden_size),
                None))

        weights[
            f'{trtllm_layer_name_prefix}.{trtllm_attn_layer_name}.rel_attn_table'] = reshape(
                split(
                    params[
                        f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
                    .T, mapping.tp_size, mapping.tp_rank, 0),
                (n_head // mapping.tp_size, config.num_buckets))

        for hf_weight_name, weight_info in hidden_layer_name_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    split(params[hf_weight_name],
                          mapping.tp_size,
                          mapping.tp_rank,
                          dim=weight_info["split_dim"]), weight_info["shape"])
        for hf_weight_name, weight_info in hidden_layer_name_no_split.items():
            if hf_weight_name in params.keys():
                weights[weight_info["name"]] = reshape(
                    params[hf_weight_name].clone(), shape=weight_info["shape"])

    weights['final_layernorm.weight'] = reshape(
        params[f'{component}.final_layer_norm.weight'].clone(), None)

    weights['lm_head.weight'] = reshape(
        split(params[f'{component}.lm_head.weight'],
              mapping.tp_size,
              mapping.tp_rank,
              dim=0), (config.vocab_size // mapping.tp_size, hidden_size))
    if not config.use_implicit_relative_attention:
        weights['rel_attn_table'] = reshape(
            split(
                params[
                    f'{component}.layer.0.self_attention.attention.relative_attention_bias.weight']
                .T, mapping.tp_size, mapping.tp_rank, 0),
            (n_head // mapping.tp_size, config.num_buckets))

    return weights
def get_model(args):
    if args.model_type == "t5":
        model = T5ForConditionalGeneration.from_pretrained(args.model_dir)
    elif args.model_type == "nmt":
        from fairseq.models.transformer import TransformerModel
        model = TransformerModel.from_pretrained(args.model_dir)
    elif args.model_type == "bart":
        if args.nougat:
            model = VisionEncoderDecoderModel.from_pretrained(args.model_dir)
            model = model.get_decoder()
        else:
            model = AutoModelForSeq2SeqLM.from_pretrained(args.model_dir)
    elif args.model_type == "pix2struct":
        model = Pix2StructForConditionalGeneration.from_pretrained(
            args.model_dir)
    elif args.model_type == "blip2":
        model = Blip2ForConditionalGeneration.from_pretrained(
            args.model_dir).language_model
    return model
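

# Note on blip2: only the BLIP-2 language model (a T5 stack) is returned here,
# so this script does not convert the vision encoder or Q-Former.
# convert_checkpoint() below remaps model_type "blip2" to the "t5" config
# parser, and convert_blip2_weights_to_tllm_safetensors is an alias of the T5
# weight converter.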
def convert_checkpoint(args):

    model = get_model(args)

    saved_dir = Path(args.output_dir)
    saved_dir.mkdir(parents=True, exist_ok=True)

    encoder_saved_dir = saved_dir / "encoder"
    encoder_saved_dir.mkdir(parents=True, exist_ok=True)
    decoder_saved_dir = saved_dir / "decoder"
    decoder_saved_dir.mkdir(parents=True, exist_ok=True)

    world_size = args.tp_size * args.pp_size

    kv_cache_quant_algo = None
    quant_algo = None

    model_type = args.model_type if args.model_type != "blip2" else "t5"
    encoder_config, decoder_config = globals()[f'parse_{model_type}_config'](
        args, model)

    additional_settings = ["gated_act"]

    if not args.nougat and args.model_type != "pix2struct":
        tllm_encoder_config = {
            'architecture': "EncoderModel",
            'dtype': args.dtype,
            'logits_dtype': encoder_config.logits_dtype,
            'num_hidden_layers': encoder_config.n_layer,
            'num_attention_heads': encoder_config.n_head,
            'hidden_size': encoder_config.hidden_size,
            'norm_epsilon': encoder_config.layernorm_eps,
            'vocab_size': encoder_config.vocab_size,
            'position_embedding_type': encoder_config.position_embedding_type,
            'hidden_act': encoder_config.hidden_act,
            'quantization': {
                'quant_algo': quant_algo,
                'kv_cache_quant_algo': kv_cache_quant_algo,
            },
            'mapping': {
                'world_size': world_size,
                'tp_size': args.tp_size,
                'pp_size': args.pp_size,
            },
            'use_parallel_embedding': args.use_parallel_embedding,
            'embedding_sharding_dim': args.embedding_sharding_dim,
            'max_position_embeddings': encoder_config.n_positions,
            'num_key_value_heads': encoder_config.n_head,
            'head_size': encoder_config.head_size,
            'has_position_embedding': encoder_config.has_position_embedding,
            'layernorm_type': encoder_config.layernorm_type,
            'has_attention_qkvo_bias': encoder_config.has_attention_qkvo_bias,
            'has_mlp_bias': encoder_config.has_mlp_bias,
            'has_model_final_layernorm':
            encoder_config.has_model_final_layernorm,
            'has_embedding_layernorm': encoder_config.has_embedding_layernorm,
            'has_embedding_scale': encoder_config.has_embedding_scale,
            'intermediate_size': encoder_config.ffn_hidden_size,
            'q_scaling': encoder_config.q_scaling,
            'layernorm_position': encoder_config.layernorm_position,
            'mlp_type': encoder_config.mlp_type,
            'relative_attention': encoder_config.relative_attention,
            'max_distance': encoder_config.max_distance,
            'num_buckets': encoder_config.num_buckets,
            'model_type': encoder_config.model_type,
        }

        for additional_setting in additional_settings:
            if hasattr(encoder_config, additional_setting):
                tllm_encoder_config.update({
                    additional_setting:
                    getattr(encoder_config, additional_setting)
                })

        with (encoder_saved_dir / "config.json").open('w') as f:
            json.dump(tllm_encoder_config, f, indent=4)

        encoder_convert_args = dict(params=model.state_dict(),
                                    component="encoder")

    tllm_decoder_config = {
        'architecture': "DecoderModel",
        'dtype': args.dtype,
        'logits_dtype': decoder_config.logits_dtype,
        'num_hidden_layers': decoder_config.n_layer,
        'num_attention_heads': decoder_config.n_head,
        'hidden_size': decoder_config.hidden_size,
        'norm_epsilon': decoder_config.layernorm_eps,
        'vocab_size': decoder_config.vocab_size,
        'position_embedding_type': decoder_config.position_embedding_type,
        'hidden_act': decoder_config.hidden_act,
        'quantization': {
            'quant_algo': quant_algo,
            'kv_cache_quant_algo': kv_cache_quant_algo,
        },
        'mapping': {
            'world_size': world_size,
            'tp_size': args.tp_size,
            'pp_size': args.pp_size,
        },
        'use_parallel_embedding': args.use_parallel_embedding,
        'embedding_sharding_dim': args.embedding_sharding_dim,
        'max_position_embeddings': decoder_config.n_positions,
        'head_size': decoder_config.head_size,
        'has_position_embedding': decoder_config.has_position_embedding,
        'layernorm_type': decoder_config.layernorm_type,
        'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias,
        'has_mlp_bias': decoder_config.has_mlp_bias,
        'has_model_final_layernorm': decoder_config.has_model_final_layernorm,
        'has_embedding_layernorm': decoder_config.has_embedding_layernorm,
        'has_embedding_scale': decoder_config.has_embedding_scale,
        'intermediate_size': decoder_config.ffn_hidden_size,
        'q_scaling': decoder_config.q_scaling,
        'layernorm_position': decoder_config.layernorm_position,
        'mlp_type': decoder_config.mlp_type,
        'relative_attention': decoder_config.relative_attention,
        'max_distance': decoder_config.max_distance,
        'num_buckets': decoder_config.num_buckets,
        'model_type': decoder_config.model_type,
        'rescale_before_lm_head': decoder_config.rescale_before_lm_head,
        'encoder_hidden_size': decoder_config.encoder_hidden_size,
        'encoder_num_heads': decoder_config.encoder_num_heads,
        'encoder_head_size': decoder_config.encoder_head_size,
        'skip_cross_kv': args.skip_cross_kv,
        'use_implicit_relative_attention': args.use_implicit_relative_attention,
        'decoder_start_token_id': decoder_config.decoder_start_token_id,
        'eos_token_id': decoder_config.eos_token_id,
        'bos_token_id': decoder_config.bos_token_id,
        'pad_token_id': decoder_config.pad_token_id,
    }

    for additional_setting in additional_settings:
        if hasattr(decoder_config, additional_setting):
            tllm_decoder_config.update({
                additional_setting:
                getattr(decoder_config, additional_setting)
            })

    with (decoder_saved_dir / "config.json").open('w') as f:
        json.dump(tllm_decoder_config, f, indent=4)

    decoder_convert_args = dict(params=model.state_dict(), component="decoder")

    if args.model_type == "nmt":
        fairseq_config = vars(model.cfg.model)  # Namespace --> dict
        num_embeddings = fairseq_config['max_source_positions']
        embedding_dim = fairseq_config['encoder_embed_dim']
        padding_idx = model.models[0].encoder.embed_tokens.padding_idx  # 1

        sin_pos_embedding = model.models[
            0].encoder.embed_positions.get_embedding(
                padding_idx + 1 + num_embeddings,
                embedding_dim,
                padding_idx=padding_idx)  # [2 + num_embeddings, embed_dim]
        sin_pos_embedding = sin_pos_embedding[2:, :]  # remove offset embeddings

        encoder_convert_args["sin_pos_embedding"] = sin_pos_embedding
        decoder_convert_args["sin_pos_embedding"] = sin_pos_embedding

    if args.workers == 1:
        if not args.nougat and args.model_type != "pix2struct":
            convert(0, world_size, args, tllm_encoder_config,
                    encoder_convert_args, encoder_saved_dir)
        convert(0, world_size, args, tllm_decoder_config, decoder_convert_args,
                decoder_saved_dir)
    else:
        if args.workers > world_size:
            args.workers = world_size
        LOGGER.info(f'Convert checkpoint using {args.workers} workers.')
        import torch.multiprocessing as mp
        if not args.nougat and args.model_type != "pix2struct":
            mp.spawn(convert,
                     nprocs=args.workers,
                     args=(world_size, args, tllm_encoder_config,
                           encoder_convert_args, encoder_saved_dir))
        mp.spawn(convert,
                 nprocs=args.workers,
                 args=(world_size, args, tllm_decoder_config,
                       decoder_convert_args, decoder_saved_dir))
def convert(worker_rank, world_size, args, model_config, convert_args,
            saved_dir):
    for rank in range(worker_rank, world_size, args.workers):
        rank_config = copy.deepcopy(PretrainedConfig.from_dict(model_config))
        rank_config.set_rank(rank)
        weights = globals(
        )[f'convert_{rank_config.model_type}_weights_to_tllm_safetensors'](
            config=rank_config, **convert_args)
        safetensors.torch.save_file(weights,
                                    f'{saved_dir}/rank{rank}.safetensors')
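

# Minimal sketch of inspecting one converted rank file (illustrative only:
# `_demo_inspect_rank0` is a hypothetical helper that is never called, and the
# path assumes the example T5 invocation near the top of this file).
def _demo_inspect_rank0():
    from safetensors.torch import load_file

    tensors = load_file('./tllm_ckpt/t5_small/decoder/rank0.safetensors')
    for name, tensor in sorted(tensors.items()):
        print(name, tuple(tensor.shape), tensor.dtype)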
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        '--model_type',
        type=str,
        default='t5',
        choices=['t5', 'nmt', 'bart', 'pix2struct', 'blip2'],
        help=
        'Multimodal type when this script is used for multimodal conversion.')
    parser.add_argument('--tp_size',
                        type=int,
                        default=1,
                        help='N-way tensor parallelism size')
    parser.add_argument('--pp_size',
                        type=int,
                        default=1,
                        help='N-way pipeline parallelism size')
    parser.add_argument("--model_dir",
                        "-i",
                        type=str,
                        help="Path to the framework checkpoint file",
                        required=True)
    parser.add_argument("--output_dir",
                        "-o",
                        type=str,
                        help="Path to the converted TRT-LLM model weight file",
                        required=True)
    parser.add_argument(
        "--workers",
        type=int,
        help="How many workers to spawn for conversion (default: 4)",
        default=4)
    parser.add_argument("--nougat",
                        action="store_true",
                        help="Model which uses vision encoder + mbart decoder")
    parser.add_argument("--verbose",
                        action="store_true",
                        help="Provide verbose messages")
    parser.add_argument(
        '--use_parallel_embedding',
        action="store_true",
        default=False,
        help=
        'By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled.'
    )
    parser.add_argument(
        '--embedding_sharding_dim',
        type=int,
        default=0,
        choices=[0, 1],
        help=
        'By default the embedding lookup table is sharded along the vocab dimension (embedding_sharding_dim=0). '
        'To shard it along the hidden dimension, set embedding_sharding_dim=1. '
        'Note: embedding sharing is only enabled when embedding_sharding_dim = 0.'
    )
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help=
        'Define the precision for the weights when using weight-only quantization. '
        'You must also pass --use_weight_only for this argument to have an effect.'
    )
    parser.add_argument(
        '--dtype',
        type=str,
        default='float16',
        choices=['float16', 'float32', 'bfloat16'],
        help=
        'Target inference dtype. Weights and computation will be in this dtype, regardless of the original dtype of the weight checkpoint.'
    )
    parser.add_argument(
        '--skip_cross_kv',
        action='store_true',
        help=
        'Skip redundant cross qkv computation by using a TensorRT IfConditional switch (experimental).'
    )
    parser.add_argument(
        '--use_implicit_relative_attention',
        action='store_true',
        help=
        'Compute relative attention bias on the fly instead of pre-computing a relative attention bias table.'
    )
    args = parser.parse_args()
    log_format = "%(asctime)s %(name)s [%(levelname)s] %(message)s"
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO,
                        format=log_format)
    LOGGER.info("\n=============== Argument ===============")
    for key in vars(args):
        LOGGER.info(f"{key}: {vars(args)[key]}")
    LOGGER.info("========================================")

    start_time = datetime.now()
    convert_checkpoint(args)
    stop_time = datetime.now()
    run_time = (stop_time - start_time)
    LOGGER.info("Spent {} (h:m:s) to convert the model".format(run_time))