# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """ | |
| The Trainer class, to easily train a π€ Transformers from scratch or finetune it on a new task. | |
| """ | |

import contextlib
import copy
import functools
import glob
import importlib.metadata
import inspect
import json
import math
import os
import random
import re
import shutil
import sys
import tempfile
import time
import warnings
from collections.abc import Mapping
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union

# Integrations must be imported before ML frameworks:
# isort: off
from transformers.integrations import (
    get_reporting_integration_callbacks,
    hp_params,
)

# isort: on

import huggingface_hub.utils as hf_hub_utils
import numpy as np
import torch
import torch.distributed as dist
from huggingface_hub import ModelCard, create_repo, upload_folder
from packaging import version
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler

from transformers import __version__
from transformers.configuration_utils import PretrainedConfig
from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
from transformers.image_processing_utils import BaseImageProcessor
from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
from transformers.integrations.tpu import tpu_spmd_dataloader
from transformers.modelcard import TrainingSummary
from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_MAPPING_NAMES,
)
from transformers.optimization import Adafactor, get_scheduler
from transformers.processing_utils import ProcessorMixin
from transformers.pytorch_utils import (
    ALL_LAYERNORM_LAYERS,
    is_torch_greater_or_equal_than_2_3,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    ExportableState,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    DistributedTensorGatherer,
    EvalLoopContainer,
    IterableDatasetShard,
    LabelSmoother,
    LayerWiseDummyOptimizer,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    find_batch_size,
    get_model_param_count,
    get_module_class_from_name,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
    remove_dummy_checkpoint,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalLoopOutput,
    EvalPrediction,
    HPSearchBackend,
    HubStrategy,
    PredictionOutput,
    RemoveColumnsCollator,
    SaveStrategy,
    TrainerMemoryTracker,
    TrainOutput,
    check_target_module_exists,
    default_compute_objective,
    denumpify_detensorize,
    enable_full_determinism,
    find_executable_batch_size,
    get_last_checkpoint,
    has_length,
    neftune_post_forward_hook,
    number_of_arguments,
    seed_worker,
    set_seed,
    speed_metrics,
)
from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
from transformers.utils import (
    ADAPTER_CONFIG_NAME,
    ADAPTER_SAFE_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    CONFIG_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
    XLA_FSDPV2_MIN_VERSION,
    PushInProgress,
    PushToHubMixin,
    can_return_loss,
    find_labels,
    is_accelerate_available,
    is_apex_available,
    is_bitsandbytes_available,
    is_datasets_available,
    is_galore_torch_available,
    is_grokadamw_available,
    is_in_notebook,
    is_ipex_available,
    is_liger_kernel_available,
    is_lomo_available,
    is_peft_available,
    is_safetensors_available,
    is_sagemaker_dp_enabled,
    is_sagemaker_mp_enabled,
    is_schedulefree_available,
    is_torch_compile_available,
    is_torch_hpu_available,  # used by `training_step` when `torch_empty_cache_steps` is set on HPU
    is_torch_mlu_available,
    is_torch_mps_available,
    is_torch_musa_available,
    is_torch_neuroncore_available,
    is_torch_npu_available,
    is_torch_xla_available,
    is_torch_xpu_available,
    is_torchao_available,
    logging,
    strtobool,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.quantization_config import QuantizationMethod


DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

if is_datasets_available():
    import datasets

if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    from torch_xla import __version__ as XLA_VERSION

    IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
    if IS_XLA_FSDPV2_POST_2_2:
        import torch_xla.distributed.spmd as xs
        import torch_xla.runtime as xr
else:
    IS_XLA_FSDPV2_POST_2_2 = False

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp
    from smdistributed.modelparallel import __version__ as SMP_VERSION

    IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")

    from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

if is_safetensors_available():
    import safetensors.torch

if is_peft_available():
    from peft import PeftModel

if is_accelerate_available():
    from accelerate import Accelerator, skip_first_batches
    from accelerate import __version__ as accelerate_version
    from accelerate.state import AcceleratorState
    from accelerate.utils import (
        DistributedDataParallelKwargs,
        DistributedType,
        load_fsdp_model,
        load_fsdp_optimizer,
        save_fsdp_model,
        save_fsdp_optimizer,
    )

    DATA_SAMPLERS = [RandomSampler]
    if version.parse(accelerate_version) > version.parse("0.23.0"):
        from accelerate.data_loader import SeedableRandomSampler

        DATA_SAMPLERS += [SeedableRandomSampler]

    if is_deepspeed_available():
        from accelerate.utils import DeepSpeedSchedulerWrapper

if is_accelerate_available("0.28.0"):
    from accelerate.utils import DataLoaderConfiguration


def _is_peft_model(model):
    if is_peft_available():
        classes_to_check = (PeftModel,) if is_peft_available() else ()
        # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
        if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
            from peft import PeftMixedModel

            classes_to_check = (*classes_to_check, PeftMixedModel)
        return isinstance(model, classes_to_check)
    return False


def _get_fsdp_ckpt_kwargs():
    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
    if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
        return {"adapter_only": True}
    else:
        return {}


def safe_globals():
    # Starting from version 2.4 PyTorch introduces a check for the objects loaded
    # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
    # a default and requires allowlisting of objects being loaded.
    # See: https://github.com/pytorch/pytorch/pull/137602
    # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
    # See: https://github.com/huggingface/accelerate/pull/3036
    if version.parse(torch.__version__).release < version.parse("2.6").release:
        return contextlib.nullcontext()

    np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
    allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
    # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
    # all versions of numpy
    allowlist += [type(np.dtype(np.uint32))]
    return torch.serialization.safe_globals(allowlist)
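

# `safe_globals()` returns a context manager, so it is meant to wrap `torch.load` calls on
# checkpoint side files, roughly:
#
#     with safe_globals():
#         state = torch.load(rng_state_path, weights_only=True)
#
# (`rng_state_path` is only illustrative here) so that the numpy objects allowlisted above
# can be deserialized under the torch>=2.6 `weights_only=True` default.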

if TYPE_CHECKING:
    import optuna

    if is_datasets_available():
        import datasets

logger = logging.get_logger(__name__)
logger.setLevel("INFO")

# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"
OPTIMIZER_NAME = "optimizer.pt"
OPTIMIZER_NAME_BIN = "optimizer.bin"
SCHEDULER_NAME = "scheduler.pt"
SCALER_NAME = "scaler.pt"
FSDP_MODEL_NAME = "pytorch_model_fsdp"

DATA_PRINT_ONCE = True
BATCH = None


def print_batch(batch, tokenizer, args):
    global DATA_PRINT_ONCE
    global BATCH
    if batch is not None:
        BATCH = batch
    else:
        batch = BATCH
        DATA_PRINT_ONCE = True
    if batch is None:
        return
    if DATA_PRINT_ONCE:
        global_rank = torch.distributed.get_rank()
        f = open(os.path.join(args.output_dir, f"print_batch_{global_rank}.log"), "a")
        torch.set_printoptions(threshold=100_000)
        if "loss_mask" in batch and batch["loss_mask"] is not None:
            loss_mask = batch["loss_mask"]
            print(f"loss_mask {loss_mask} {loss_mask.size()}", file=f)
        if "position_ids" in batch and batch["position_ids"] is not None:
            position_ids = batch["position_ids"]
            print(f"position_ids {position_ids} {position_ids.size()}", file=f)
        if "attention_mask" in batch and batch["attention_mask"] is not None:
            attention_mask = batch["attention_mask"]
            if isinstance(attention_mask, list):
                attention_mask = attention_mask[0]
            print(f"attention_mask {attention_mask} {attention_mask.size()}", file=f)
        if "input_ids" in batch and batch["input_ids"] is not None:
            tokens = batch["input_ids"]
            print(f"tokens {tokens} {tokens.size()}", file=f)
            tokens_ = tokens.cpu().clone().detach()
            tokens_ = tokenizer.batch_decode(tokens_.tolist(), skip_special_tokens=False)
            print(f"tokens_ {tokens_[:]}", file=f)
        if "labels" in batch and batch["labels"] is not None:
            labels = batch["labels"]
            print(f"labels {labels} {labels.size()}", file=f)
            labels_ = labels.cpu().clone().detach()
            labels_[labels_ == -100] = tokenizer("-", add_special_tokens=False).input_ids[0]
            labels_ = tokenizer.batch_decode(labels_.tolist(), skip_special_tokens=False)
            print(f"labels {labels_}", file=f)
            # labels__ = labels.cpu().clone().detach()
            # labels__[loss_mask.to(torch.int64)==0] = tokenizer("-", add_special_tokens=False).input_ids[0]
            # labels__ = tokenizer.batch_decode(labels__.tolist(), skip_special_tokens=False)
            # print(f"labels__ {labels__}", file=f)
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                print(f"{k} {v} {v.size()}", file=f)
            else:
                print(f"{k} {v}", file=f)
        f.close()
        DATA_PRINT_ONCE = False
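

# `print_batch` is a rank-local debugging helper: each process appends a readable dump of the
# first batch it sees to `print_batch_<rank>.log` under `args.output_dir`, and calling it again
# with `batch=None` re-arms the one-shot flag so the cached batch is dumped once more. It
# assumes `torch.distributed` has been initialized before the first call.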

from transformers import Trainer as HFTrainer


class Trainer(HFTrainer):
    def get_train_dataloader(self) -> DataLoader:
        """
        Returns the training [`~torch.utils.data.DataLoader`].

        Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
        training if necessary) otherwise.

        Subclass and override this method if you want to inject some custom behavior.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        train_dataset = self.train_dataset
        data_collator = self.data_collator
        if is_datasets_available() and isinstance(train_dataset, datasets.Dataset):
            train_dataset = self._remove_unused_columns(train_dataset, description="training")
        else:
            data_collator = self._get_collator_with_removed_columns(data_collator, description="training")

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
            "multiprocessing_context": "spawn",
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_train_sampler()
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
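
    # `"multiprocessing_context": "spawn"` makes dataloader workers start via spawn rather than
    # fork; it only has an effect when `dataloader_num_workers > 0`, and is commonly used to
    # avoid fork-related issues with CUDA contexts or multithreaded tokenizers.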

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

        if self.optimizer is None:
            decay_parameters = self.get_decay_parameter_names(opt_model)

            if self.args.vision_model_lr_mult != 1.0 or self.args.vision_model_lr_decay_rate != 1.0:
                vision_parameters = [name for name, _ in opt_model.named_parameters() if "vision_model" in name]
                logger.info(f"{vision_parameters=}")
            else:
                vision_parameters = []

            if self.args.mtp_model_lr_mult != 1.0:
                mtp_parameters = []
                mtp_names = ["mtp"]
                num_nextn_predict_layers = self.model.config.num_nextn_predict_layers
                num_hidden_layers = self.model.config.num_hidden_layers
                for mtp_idx in range(num_nextn_predict_layers):
                    layer_idx = num_hidden_layers - num_nextn_predict_layers + mtp_idx
                    mtp_names.append(f"model.layers.{layer_idx}")
                for name, param in opt_model.named_parameters():
                    if any(x in name for x in mtp_names):
                        mtp_parameters.append(name)
                logger.info(f"{mtp_parameters=}")
            else:
                mtp_parameters = []

            exclude_parameters = vision_parameters + mtp_parameters

            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad and n not in exclude_parameters)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad and n not in exclude_parameters)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            if self.args.vision_model_lr_decay_rate != 1.0:
                for n, p in opt_model.named_parameters():
                    if not (p.requires_grad and n in vision_parameters):
                        continue
                    if n in decay_parameters:
                        weight_decay = self.args.weight_decay
                    else:
                        weight_decay = 0.0
                    lr = self.args.learning_rate * get_vit_lr_decay_rate(
                        n, opt_model.config.visual.num_hidden_layers, self.args.vision_model_lr_decay_rate
                    )
                    optimizer_grouped_parameters.append(
                        {
                            "params": [p],
                            "weight_decay": weight_decay,
                            "lr": lr,
                        }
                    )
                    logger.info(f"create_optimizer name {n} weight_decay {weight_decay} lr {lr}")
            elif self.args.vision_model_lr_mult != 1.0:
                optimizer_grouped_parameters.extend(
                    [
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n in decay_parameters and p.requires_grad and n in vision_parameters)
                            ],
                            "weight_decay": self.args.weight_decay,
                            "lr": self.args.learning_rate * self.args.vision_model_lr_mult,
                        },
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n not in decay_parameters and p.requires_grad and n in vision_parameters)
                            ],
                            "weight_decay": 0.0,
                            "lr": self.args.learning_rate * self.args.vision_model_lr_mult,
                        },
                    ]
                )
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.vision_model_lr_mult}")
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in vision_parameters)]} weight_decay {0.0} lr_mult {self.args.vision_model_lr_mult}")

            if self.args.mtp_model_lr_mult != 1.0:
                optimizer_grouped_parameters.extend(
                    [
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n in decay_parameters and p.requires_grad and n in mtp_parameters)
                            ],
                            "weight_decay": self.args.weight_decay,
                            "lr": self.args.learning_rate * self.args.mtp_model_lr_mult,
                        },
                        {
                            "params": [
                                p
                                for n, p in opt_model.named_parameters()
                                if (n not in decay_parameters and p.requires_grad and n in mtp_parameters)
                            ],
                            "weight_decay": 0.0,
                            "lr": self.args.learning_rate * self.args.mtp_model_lr_mult,
                        },
                    ]
                )
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {self.args.weight_decay} lr_mult {self.args.mtp_model_lr_mult}")
                logger.info(f"create_optimizer name {[n for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad and n in mtp_parameters)]} weight_decay {0.0} lr_mult {self.args.mtp_model_lr_mult}")

            if self.optimizer_cls_and_kwargs is not None:
                optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
            else:
                optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, opt_model)

            # Overwrite `params` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for GaLore optimizer.
            if "params" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("params")

            # Overwrite `model` in case it's created by `get_optimizer_cls_and_kwargs`
            # e.g. for LOMO optimizer.
            if "model" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("model")

            # For layer-wise dummy optimizers we overwrite optimizer_grouped_parameters with `optimizer_dict`
            # to avoid arguments conflicts.
            if "optimizer_dict" in optimizer_kwargs:
                optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

            if optimizer_cls.__name__ == "Adam8bit":
                import bitsandbytes

                manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

                skipped = 0
                for module in opt_model.modules():
                    if isinstance(module, nn.Embedding):
                        skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                        logger.info(f"skipped {module}: {skipped/2**20}M params")
                        manager.register_module_override(module, "weight", {"optim_bits": 32})
                        logger.debug(f"bitsandbytes: will optimize {module} in fp32")
                logger.info(f"skipped: {skipped/2**20}M params")

        if is_sagemaker_mp_enabled():
            self.optimizer = smp.DistributedOptimizer(self.optimizer)

        return self.optimizer
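
    # The grouping above assumes a TrainingArguments subclass with three extra fields:
    # `vision_model_lr_mult`, `vision_model_lr_decay_rate` and `mtp_model_lr_mult`, plus a
    # model config exposing `num_nextn_predict_layers`, `num_hidden_layers` and
    # `visual.num_hidden_layers`. With all three fields at their default of 1.0 the method
    # reduces to the standard two-group (decay / no-decay) parameter setup.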

    def training_step(
        self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to train.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.

        Return:
            `torch.Tensor`: The tensor with training loss on this batch.
        """
        print_batch(inputs, self.processing_class, self.args)
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            if is_torch_xpu_available():
                torch.xpu.empty_cache()
            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
            elif is_torch_musa_available():
                torch.musa.empty_cache()
            elif is_torch_npu_available():
                torch.npu.empty_cache()
            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            elif is_torch_hpu_available():
                logger.warning(
                    "`torch_empty_cache_steps` is set but HPU device/backend does not support empty_cache()."
                )
            else:
                torch.cuda.empty_cache()

        kwargs = {}

        # For LOMO optimizers you need to explicitly use the learning rate
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            # Finally we need to normalize the loss for reporting
            if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                loss = loss / self.args.gradient_accumulation_steps

            # Turning off loss scaling w.r.t. gradient accumulation when DeepSpeed is enabled
            # https://github.com/huggingface/transformers/pull/35808
            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs["scale_wrt_gas"] = False

            self.accelerator.backward(loss, **kwargs)

        return loss.detach()
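
    # The only custom behavior in this override is the `print_batch(...)` call at the top of the
    # step, which dumps the first batch per rank for inspection; the rest closely mirrors the
    # stock `Trainer.training_step` flow (SageMaker MP short-circuit, optional cache emptying,
    # loss scaling and `accelerator.backward`).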

    def get_batch_samples(self, epoch_iterator, num_batches):
        batch_samples = []
        num_items_in_batch = None
        for _ in range(num_batches):
            try:
                while True:
                    batch_sample = next(epoch_iterator)
                    if "input_ids" in batch_sample:
                        break
                batch_samples += [batch_sample]
            except StopIteration:
                break

        if len(batch_samples) > 0 and "labels" in batch_samples[0]:
            # For now we don't support object detection
            try:
                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
            except (TypeError, AttributeError):
                pass

        if self.args.average_tokens_across_devices and num_items_in_batch is not None:
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()

        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.item()

        return batch_samples, num_items_in_batch
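
    # Unlike the stock implementation, the inner `while True` loop skips any yielded batch that
    # has no "input_ids" key (the reason is not stated in this file; presumably the custom
    # collator can emit such batches). `num_items_in_batch` then counts label tokens != -100
    # across the accepted batches and is used for loss normalization.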


def get_vit_lr_decay_rate(name, num_layers, lr_decay_rate):
    layer_id = num_layers + 1
    if "vision_model." in name:
        if ".position_embedding." in name or ".conv1." in name:
            layer_id = 0
        elif ".layers." in name:
            layer_id = int(name[name.find(".layers.") :].split(".")[2]) + 1
    return lr_decay_rate ** (num_layers + 1 - layer_id)
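

# A minimal sketch of how the layer-wise decay behaves: parameters closer to the input of the
# vision tower get a smaller learning-rate multiplier, non-vision parameters get 1.0. The
# parameter names and layer counts below are hypothetical and only for illustration.
if __name__ == "__main__":
    demo_names = [
        "vision_model.embeddings.position_embedding.weight",  # layer_id = 0  -> 0.9 ** 25
        "vision_model.encoder.layers.0.mlp.fc1.weight",  # layer_id = 1  -> 0.9 ** 24
        "vision_model.encoder.layers.23.mlp.fc1.weight",  # layer_id = 24 -> 0.9 ** 1
        "language_model.layers.0.mlp.gate_proj.weight",  # not a vision param -> 1.0
    ]
    for demo_name in demo_names:
        mult = get_vit_lr_decay_rate(demo_name, num_layers=24, lr_decay_rate=0.9)
        print(f"{demo_name}: lr multiplier {mult:.4f}")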