""" Example Python script to generate a YAML config file which can be used to run a training with nanotron. Refer to the "examples" section in `/README.md` for more information.

Usage:
```
python config_tiny_mistral.py
```
"""
import os
from dataclasses import dataclass
from typing import Optional

from nanotron.config import (
    CheckpointsArgs,
    Config,
    DataArgs,
    GeneralArgs,
    LoggingArgs,
    LRSchedulerArgs,
    ModelArgs,
    OptimizerArgs,
    ParallelismArgs,
    PretrainDatasetsArgs,
    RandomInit,
    TokenizerArgs,
    TokensArgs,
)
from nanotron.logging import human_format

from config_mistral import MistralConfig, get_num_params

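# A deliberately tiny Mistral architecture (2 layers, hidden size 16, vocab 256),
# meant for quickly smoke-testing the training pipeline rather than real pretraining.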
model_config = MistralConfig(
    attn_pdrop=0.0,
    bos_token_id=1,
    eos_token_id=2,
    hidden_act="silu",
    hidden_size=16,
    initializer_range=0.02,
    intermediate_size=64,
    max_position_embeddings=256,
    num_attention_heads=4,
    num_hidden_layers=2,
    num_key_value_heads=4,
    pretraining_tp=1,
    rms_norm_eps=1e-05,
    rope_theta=10000.0,
    tie_word_embeddings=True,
    use_cache=True,
    vocab_size=256,
)

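# Human-readable parameter count, with "." replaced by "p" (e.g. "1.6M" -> "1p6M").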
num_params = human_format(get_num_params(model_config)).replace(".", "p")

print(f"Model has {num_params} parameters")

seed = 42

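# Learning-rate schedule: linear warmup for 2 steps, then cosine decay down to min_decay_lr.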
learning_rate = LRSchedulerArgs(
    learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5
)

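# Adam with weight decay, gradient clipping at 1.0, and fp32 gradient accumulation; ZeRO is disabled (stage 0).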
optimizer = OptimizerArgs(
    zero_stage=0,
    weight_decay=0.01,
    clip_grad=1.0,
    accumulate_grad_in_fp32=True,
    adam_eps=1e-08,
    adam_beta1=0.9,
    adam_beta2=0.95,
    torch_adam_is_fused=True,
    learning_rate_scheduler=learning_rate,
)

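# 2-way data, pipeline, and tensor parallelism (dp * pp * tp = 8 processes in total),
# using the 1F1B pipeline schedule, reduce-scatter tensor parallelism with async linear
# communication, and selective activation recomputation.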
parallelism = ParallelismArgs(
    dp=2,
    pp=2,
    tp=2,
    pp_engine="1f1b",
    tp_mode="REDUCE_SCATTER",
    tp_linear_async_communication=True,
    recompute_granularity="selective",
)

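# Very short run: sequence length 32, micro-batch size 2, no gradient accumulation, 10 training steps.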
tokens = TokensArgs(sequence_length=32, train_steps=10, micro_batch_size=2, batch_accumulation_per_replica=1)

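# Small test dataset from the Hugging Face Hub; the "completion" column provides the training text.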
dataset = PretrainDatasetsArgs(
    hf_dataset_or_datasets="HuggingFaceH4/testing_alpaca_small", text_column_name="completion"
)

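# Checkpoints go into a `checkpoints/` folder one level above this script's directory.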
checkpoints_path = os.path.dirname(os.path.dirname(__file__)) + "/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

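# Assemble the full nanotron Config: weights are randomly initialized (std 0.025),
# the "gpt2" tokenizer is used, and a checkpoint is saved every 10 steps.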
config = Config(
    general=GeneralArgs(project="debug", run="tiny_mistral", seed=seed, step=0),
    checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10),
    parallelism=parallelism,
    model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config),
    tokenizer=TokenizerArgs("gpt2"),
    optimizer=optimizer,
    logging=LoggingArgs(),
    tokens=tokens,
    data=DataArgs(dataset=dataset, seed=seed),
    profiler=None,
    lighteval=None,
)

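# When executed directly, save the config as a YAML file next to this script,
# swapping the ".py" suffix for ".yaml".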
if __name__ == "__main__":
    file_path = os.path.abspath(__file__)

    file_path = file_path.replace(".py", ".yaml")

    config.save_as_yaml(file_path)