# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import signal
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Dict, Optional

import torch
import transformers
from peft import PeftModel
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length
from transformers.utils import (
    SAFE_WEIGHTS_NAME,
    WEIGHTS_NAME,
    is_safetensors_available,
)

from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger


if is_safetensors_available():
    from safetensors import safe_open
    from safetensors.torch import save_file


if TYPE_CHECKING:
    from transformers import TrainerControl, TrainerState, TrainingArguments
    from trl import AutoModelForCausalLMWithValueHead


logger = get_logger(__name__)


def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    The model is already unwrapped.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    if safe_serialization:
        path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f:
            state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()}
    else:
        path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu")

    decoder_state_dict = {}
    v_head_state_dict = {}
    for name, param in state_dict.items():
        if name.startswith("v_head."):
            v_head_state_dict[name] = param
        else:
            decoder_state_dict[name.replace("pretrained_model.", "", 1)] = param

    model.pretrained_model.save_pretrained(
        output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization
    )

    if safe_serialization:
        save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"})
    else:
        torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

    os.remove(path_to_checkpoint)
    logger.info("Value head model saved at: {}".format(output_dir))


class FixValueHeadModelCallback(TrainerCallback):
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if args.should_save:
            fix_valuehead_checkpoint(
                model=kwargs.pop("model"),
                output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)),
                safe_serialization=args.save_safetensors,
            )


class SaveProcessorCallback(TrainerCallback):
    def __init__(self, processor: "ProcessorMixin") -> None:
        r"""
        Initializes a callback for saving the processor.
        """
        self.processor = processor

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
            getattr(self.processor, "image_processor").save_pretrained(args.output_dir)


class PissaConvertCallback(TrainerCallback):
    r"""
    Initializes a callback for converting the PiSSA adapter to a normal one.
    """

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with init_lora_weights: pissa
            # 3. load the pissa backup with init_lora_weights: True
            # 4. delete the initial adapter and change init_lora_weights to pissa
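            # (Assumed PEFT behavior) the conversion in step 2 loads the initial adapter under the
            # name "pissa_init" to compute the converted weights, which is why steps 3-4 restore
            # the backup as the active "default" adapter and delete the temporary "pissa_init".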
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                setattr(model.peft_config["default"], "init_lora_weights", True)
                model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors)
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
                model.save_pretrained(
                    pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
                )
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                model.delete_adapter("pissa_init")
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)


class LogCallback(TrainerCallback):
    def __init__(self) -> None:
        r"""
        Initializes a callback for logging training and evaluation status.
        """
        """ Progress """
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        """ Status """
        self.aborted = False
        self.do_train = False
        """ Web UI """
        self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
        if self.webui_mode:
            signal.signal(signal.SIGABRT, self._set_abort)
            self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
            logging.root.addHandler(self.logger_handler)
            transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        self.start_time = time.time()
        self.cur_steps = 0
        self.max_steps = max_steps
        self.elapsed_time = ""
        self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        cur_time = time.time()
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_time = (self.max_steps - cur_steps) * avg_time_per_step
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
        self.remaining_time = str(timedelta(seconds=int(remaining_time)))
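        # Worked example (hypothetical numbers): 100 s elapsed after 20 of 120 steps gives
        # 5 s/step, so remaining_time = (120 - 20) * 5 = 500 s, formatted as "0:08:20".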
    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f:
            f.write(json.dumps(logs) + "\n")

    def _create_thread_pool(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        if self.thread_pool is not None:
            self.thread_pool.shutdown(wait=True)
            self.thread_pool = None

    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of the initialization of the `Trainer`.
        """
        if (
            args.should_save
            and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
            and args.overwrite_output_dir
        ):
            logger.warning("Previous trainer log in this folder will be deleted.")
            os.remove(os.path.join(args.output_dir, TRAINER_LOG))

    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            self.do_train = True
            self._reset(max_steps=state.max_steps)
            self._create_thread_pool(output_dir=args.output_dir)

    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        self._close_thread_pool()
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a substep during gradient accumulation.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of a training step.
        """
        if self.aborted:
            control.should_epoch_stop = True
            control.should_training_stop = True

    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after an evaluation phase.
        """
        if not self.do_train:
            self._close_thread_pool()

    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a successful prediction.
        """
        if not self.do_train:
            self._close_thread_pool()

    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after logging the last logs.
        """
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        logs = dict(
            current_steps=self.cur_steps,
            total_steps=self.max_steps,
            loss=state.log_history[-1].get("loss", None),
            eval_loss=state.log_history[-1].get("eval_loss", None),
            predict_loss=state.log_history[-1].get("predict_loss", None),
            reward=state.log_history[-1].get("reward", None),
            accuracy=state.log_history[-1].get("rewards/accuracies", None),
            learning_rate=state.log_history[-1].get("learning_rate", None),
            epoch=state.log_history[-1].get("epoch", None),
            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
            elapsed_time=self.elapsed_time,
            remaining_time=self.remaining_time,
            throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
            total_tokens=state.num_input_tokens_seen,
        )
        logs = {k: v for k, v in logs.items() if v is not None}
        if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
            logger.info(
                "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
                    logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
                )
            )

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        r"""
        Event called after a prediction step.
        """
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if has_length(eval_dataloader):
            if self.max_steps == 0:
                self._reset(max_steps=len(eval_dataloader))
                self._create_thread_pool(output_dir=args.output_dir)

            self._timing(cur_steps=self.cur_steps + 1)
            if self.cur_steps % 5 == 0 and self.thread_pool is not None:
                logs = dict(
                    current_steps=self.cur_steps,
                    total_steps=self.max_steps,
                    percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
                    elapsed_time=self.elapsed_time,
                    remaining_time=self.remaining_time,
                )
                self.thread_pool.submit(self._write_log, args.output_dir, logs)
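

# Illustrative usage (a sketch, not part of the original module): these callbacks subclass
# `transformers.TrainerCallback` and are meant to be passed to a `transformers.Trainer`.
# The `trainer`, `training_args`, and `processor` names below are hypothetical placeholders.
#
#     callbacks = [LogCallback()]
#     if processor is not None:
#         callbacks.append(SaveProcessorCallback(processor))
#     trainer = Trainer(model=model, args=training_args, callbacks=callbacks)
#     trainer.add_callback(FixValueHeadModelCallback())  # only for models with a value head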