"""
This module converts a transformers MistralForCausalLM to a brrr model

Command:
torchrun --nproc_per_node=1 convert_trfrs_to_brrr.py \
    --model_name mistralai/Mistral-7B-v0.1 \
    --save_path ./pretrained/Mistral-7B-v0.1
"""
import argparse
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Dict, List

import torch

from brrr.trainer import DistributedTrainer

sys.path.append(Path(__file__).parent.parent.as_posix())
import os

from nanotron.parallel.parameters import NanotronParameter, sanity_check
from nanotron.parallel.pipeline_parallel.engine import (
    AllForwardAllBackwardPipelineEngine,
)
from nanotron.parallel.tensor_parallel.nn import TensorParallelLinearMode
from transformers import MistralConfig as MistralConfig_trfs, MistralForCausalLM

import nanotron.distributed as dist
from nanotron.config import ParallelismArgs, RecomputeGranularity
from nanotron.parallel.context import ParallelContext
from nanotron.models import build_model
from nanotron.trainer import mark_tied_parameters
from nanotron.serialize import save_meta, save_weights, save

from modeling_mistral import MistralForTraining
from config_mistral_7b import PARALLELISM as PARALLELISM_BRRR, CONFIG as CONFIG_BRRR


def get_args():
    parser = argparse.ArgumentParser(description="Convert transformers weights to brrr weights")
    parser.add_argument("--model_name", type=str, default="mistralai/Mistral-7B-v0.1")
    parser.add_argument("--save_path", type=str, default="pretrained/Mistral-7B-v0.1")
    parser.add_argument("--dp", type=int, default=1)
    parser.add_argument("--pp", type=int, default=1)
    parser.add_argument("--tp", type=int, default=1)
    return parser.parse_args()


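# The HF checkpoint stores q/k projection rows with each head's rotary dims
# split into two contiguous halves; this view/transpose interleaves the two
# halves again (presumably the layout expected by the brrr rotary embedding).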
def permute_for_rotary(tensor, num_heads, per_head_hidden_size, hidden_size):
    return (
        tensor.view(num_heads, 2, per_head_hidden_size // 2, hidden_size)
        .transpose(1, 2)
        .contiguous()
        .view(num_heads * per_head_hidden_size, hidden_size)
    )


def get_transformers_weight(
    name: str, ref_module_state_dict: Dict[str, torch.Tensor], ref_module: MistralForCausalLM, get_grad: bool = False
) -> torch.Tensor:
    """Given a parameter name from the brrr implementation, return the equivalent tensor from the transformers implementation."""
    config = ref_module.config
    brrr_prefix = "model."
    assert name.startswith(brrr_prefix)
    name = name[len(brrr_prefix) :]

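    # brrr wraps each module in a pipeline block, which adds a "pp_block"
    # segment to parameter names; drop it so the remaining path lines up with
    # the transformers module naming.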
    path = name.split(".")
    path.remove("pp_block")
    name = ".".join(path)

    if get_grad is False:

        def get_tensor(path: str):
            return ref_module_state_dict[path]

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    else:

        def get_tensor(path: str):
            weight = ref_module.get_parameter(path)
            return weight.grad

        def get_tensors(path: List[str]):
            return [get_tensor(p) for p in path]

    if name == "token_position_embeddings.token_embedding.weight":
        return get_tensor("model.embed_tokens.weight")

    elif name == "lm_head.weight":
        return get_tensor("lm_head.weight")

    elif name == "final_layer_norm.weight":
        return get_tensor("model.norm.weight")

    if path[0] == "decoder":
        transformer_path = ["model"] + ["layers"] + [path[1]]

        if path[2] == "attn":
            path[2] = "self_attn"

        if path[2] == "ff":
            path[2] = "mlp"

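        # brrr fuses the q/k/v projections into a single qkv_proj weight, so the
        # three transformers tensors are fetched separately, q and k are permuted
        # for the rotary layout, and everything is concatenated along dim 0.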
        if path[3] == "qkv_proj":
            proj_names = ["q_proj", "k_proj", "v_proj"]
            tensor_list = get_tensors(
                [".".join(transformer_path + path[2:3] + [proj_name] + path[4:]) for proj_name in proj_names]
            )

            per_head_hidden_size = config.hidden_size // config.num_attention_heads

            print(f"Permuting q {tensor_list[0].shape}")
            tensor_list[0] = permute_for_rotary(
                tensor=tensor_list[0],
                num_heads=config.num_attention_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )

            print(f"Permuting k {tensor_list[1].shape}")
            tensor_list[1] = permute_for_rotary(
                tensor=tensor_list[1],
                num_heads=config.num_key_value_heads,
                per_head_hidden_size=per_head_hidden_size,
                hidden_size=config.hidden_size,
            )
            return torch.cat(tensor_list, dim=0)

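        # Similarly, brrr fuses the MLP gate and up projections into a single
        # gate_up_proj weight, concatenated along dim 0.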
        if path[3] == "gate_up_proj":
            tensor_list = get_tensors(
                [
                    ".".join(transformer_path + path[2:3] + [proj_name] + path[4:])
                    for proj_name in ["gate_proj", "up_proj"]
                ]
            )
            return torch.cat(tensor_list, dim=0)

        return get_tensor(".".join(transformer_path + path[2:]))

    else:
        raise ValueError(f"Couldn't find transformer equivalent of {name}")


def convert_trfrs_to_brrr(
    dp, pp, tp, model_name="mistralai/Mistral-7B-v0.1", save_path="pretrained/Mistral-7B-v0.1"
):
    save_path = Path(save_path)

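    # Start from the reference parallelism config and override the 3D-parallel
    # sizes with the values passed on the command line.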
    parallel_config = PARALLELISM_BRRR
    parallel_config.dp = dp
    parallel_config.pp = pp
    parallel_config.tp = tp

    parallel_context = ParallelContext(
        data_parallel_size=parallel_config.dp,
        pipeline_parallel_size=parallel_config.pp,
        tensor_parallel_size=parallel_config.tp,
    )

    dtype = torch.bfloat16

    model_config_brrr = CONFIG_BRRR.model.model_config

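    # Instantiate the brrr model on CPU; its parameters are overwritten below
    # with the transformers weights.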
    model = build_model(
        model_builder=lambda: MistralForTraining(
            config=model_config_brrr,
            parallel_context=parallel_context,
            parallel_config=parallel_config,
            random_states=None,
        ),
        dtype=dtype,
        parallel_context=parallel_context,
        device=torch.device("cpu"),
    )

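    # Build a transformers device_map so that each pipeline-parallel rank only
    # materializes the layers it owns, leaving everything else on the meta
    # device.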
    device_map = {}
    current_pp_rank = dist.get_rank(group=parallel_context.pp_pg)
    device_map["model.embed_tokens"] = (
        model.model.token_position_embeddings.rank
        if current_pp_rank == model.model.token_position_embeddings.rank
        else "meta"
    )
    for i in range(model_config_brrr.num_hidden_layers):
        device_map[f"model.layers.{i}"] = (
            model.model.decoder[i].rank if current_pp_rank == model.model.decoder[i].rank else "meta"
        )
    device_map["model.norm"] = (
        model.model.final_layer_norm.rank if current_pp_rank == model.model.final_layer_norm.rank else "meta"
    )
    device_map["lm_head"] = model.model.lm_head.rank if current_pp_rank == model.model.lm_head.rank else "meta"
    model_ref = MistralForCausalLM.from_pretrained(model_name, torch_dtype=dtype, device_map=device_map)

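    # Copy every brrr parameter from its transformers counterpart. For
    # tensor-parallel sharded parameters only the local slice is copied; all
    # other parameters are copied whole after a shape check.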
    ref_state_dict = model_ref.state_dict()
    for name, param in model.named_parameters():
        print(f"Syncing {name}")
        ref_param = get_transformers_weight(name=name, ref_module_state_dict=ref_state_dict, ref_module=model_ref)

        param_is_tp_sharded = (
            isinstance(param, NanotronParameter)
            and param.is_sharded
            and parallel_context.world_ranks_to_pg[param.get_sharded_info().global_ranks] == parallel_context.tp_pg
        )

        if param_is_tp_sharded:
            sharded_info = param.get_sharded_info()

            with torch.no_grad():
                for local_global_slices_pair in sharded_info.local_global_slices_pairs:
                    local_slices = local_global_slices_pair.local_slices
                    global_slices = local_global_slices_pair.global_slices
                    param[local_slices].copy_(ref_param[global_slices])
        else:
            assert (
                ref_param.shape == param.shape
            ), f"Parameter shapes don't match for {name}\n{ref_param.shape} != {param.shape}"

            with torch.no_grad():
                param.copy_(ref_param)
        ref_param = None

    mark_tied_parameters(model=model, parallel_context=parallel_context, parallel_config=parallel_config)

    sanity_check(root_module=model)

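    # Write the converted weights as a brrr checkpoint. Optimizer and LR
    # scheduler state are skipped since this is a fresh conversion.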
    checkpoint_metadata = {
        "last_train_step": 0,
        "consumed_train_samples": 0,
    }
    save(
        config=CONFIG_BRRR,
        model=model,
        optimizer=None,
        lr_scheduler=None,
        parallel_context=parallel_context,
        root_folder=save_path,
        should_save_optimizer=False,
        should_save_lr_scheduler=False,
        checkpoint_metadata=checkpoint_metadata,
        sanity_checks=False,
    )

    if dist.get_rank(parallel_context.world_pg) == 0:
        print(save_path)
        import json

        with open(save_path / "model_config.json", mode="w") as fo:
            fo.write(json.dumps(asdict(CONFIG_BRRR.model.model_config), indent=4))


def main():
    args = get_args()
    convert_trfrs_to_brrr(**vars(args))


if __name__ == "__main__":
    main()