| """ | |
| Trainer Class | |
| ============= | |
| """ | |
| import collections | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import scipy | |
| import torch | |
| import tqdm | |
| import transformers | |
| import textattack | |
| from textattack.shared.utils import logger | |
| from .attack import Attack | |
| from .attack_args import AttackArgs | |
| from .attack_results import MaximizedAttackResult, SuccessfulAttackResult | |
| from .attacker import Attacker | |
| from .model_args import HUGGINGFACE_MODELS | |
| from .models.helpers import LSTMForClassification, WordCNNForClassification | |
| from .models.wrappers import ModelWrapper | |
| from .training_args import CommandLineTrainingArgs, TrainingArgs | |
class Trainer:
    """Trainer is the training and eval loop for adversarial training.

    It is designed to work with PyTorch and Transformers models.

    Args:
        model_wrapper (:class:`~textattack.models.wrappers.ModelWrapper`):
            Model wrapper containing both the model and the tokenizer.
        task_type (:obj:`str`, `optional`, defaults to :obj:`"classification"`):
            The task that the model is trained to perform.
            Currently, :class:`~textattack.Trainer` supports two tasks: (1) :obj:`"classification"`, (2) :obj:`"regression"`.
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to generate adversarial examples for training.
        train_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for training.
        eval_dataset (:class:`~textattack.datasets.Dataset`):
            Dataset for evaluation.
        training_args (:class:`~textattack.TrainingArgs`):
            Arguments for training.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> # We only use DeepWordBugGao2018 for demonstration purposes.
        >>> attack = textattack.attack_recipes.DeepWordBugGao2018.build(model_wrapper)
        >>> train_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="train")
        >>> eval_dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Train for 3 epochs with 1 initial clean epoch, 1000 adversarial examples per epoch, a learning rate of 5e-5, and an effective batch size of 32 (8x4).
        >>> training_args = textattack.TrainingArgs(
        ...     num_epochs=3,
        ...     num_clean_epochs=1,
        ...     num_train_adv_examples=1000,
        ...     learning_rate=5e-5,
        ...     per_device_train_batch_size=8,
        ...     gradient_accumulation_steps=4,
        ...     log_to_tb=True,
        ... )

        >>> trainer = textattack.Trainer(
        ...     model_wrapper,
        ...     "classification",
        ...     attack,
        ...     train_dataset,
        ...     eval_dataset,
        ...     training_args
        ... )
        >>> trainer.train()

    .. note::
        When using :class:`~textattack.Trainer` with ``parallel=True`` in :class:`~textattack.TrainingArgs`,
        make sure to protect the "entry point" of the program by using :obj:`if __name__ == '__main__':`.
        If not, each worker process used for generating adversarial examples will execute the training code again.
    """
    def __init__(
        self,
        model_wrapper,
        task_type="classification",
        attack=None,
        train_dataset=None,
        eval_dataset=None,
        training_args=None,
    ):
        assert isinstance(
            model_wrapper, ModelWrapper
        ), f"`model_wrapper` must be of type `textattack.models.wrappers.ModelWrapper`, but got type `{type(model_wrapper)}`."

        # TODO: Support seq2seq training
        assert task_type in {
            "classification",
            "regression",
        }, '`task_type` must either be "classification" or "regression"'

        if attack:
            assert isinstance(
                attack, Attack
            ), f"`attack` argument must be of type `textattack.Attack`, but got type `{type(attack)}`."
            if id(model_wrapper) != id(attack.goal_function.model):
                logger.warning(
                    "`model_wrapper` and the victim model of `attack` are not the same model."
                )

        if train_dataset:
            assert isinstance(
                train_dataset, textattack.datasets.Dataset
            ), f"`train_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(train_dataset)}`."

        if eval_dataset:
            assert isinstance(
                eval_dataset, textattack.datasets.Dataset
            ), f"`eval_dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(eval_dataset)}`."

        if training_args:
            assert isinstance(
                training_args, TrainingArgs
            ), f"`training_args` must be of type `textattack.TrainingArgs`, but got type `{type(training_args)}`."
        else:
            training_args = TrainingArgs()

        if not hasattr(model_wrapper, "model"):
            raise ValueError("Cannot detect `model` in `model_wrapper`")
        else:
            assert isinstance(
                model_wrapper.model, torch.nn.Module
            ), f"`model` in `model_wrapper` must be of type `torch.nn.Module`, but got type `{type(model_wrapper.model)}`."

        if not hasattr(model_wrapper, "tokenizer"):
            raise ValueError("Cannot detect `tokenizer` in `model_wrapper`")

        self.model_wrapper = model_wrapper
        self.task_type = task_type
        self.attack = attack
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.training_args = training_args

        self._metric_name = (
            "pearson_correlation" if self.task_type == "regression" else "accuracy"
        )
        if self.task_type == "regression":
            self.loss_fct = torch.nn.MSELoss(reduction="none")
        else:
            self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

        self._global_step = 0
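    # NOTE: `reduction="none"` above is deliberate. It keeps the per-sample
    # losses so that `training_step` can up-weight adversarial samples via
    # `training_args.alpha` before averaging over the batch.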
    def _generate_adversarial_examples(self, epoch):
        """Generate adversarial examples using attacker."""
        assert (
            self.attack is not None
        ), "`attack` is `None` but attempting to generate adversarial examples."
        base_file_name = f"attack-train-{epoch}"
        log_file_name = os.path.join(self.training_args.output_dir, base_file_name)
        logger.info("Attacking model to generate new adversarial training set...")

        if isinstance(self.training_args.num_train_adv_examples, float):
            num_train_adv_examples = math.ceil(
                len(self.train_dataset) * self.training_args.num_train_adv_examples
            )
        else:
            num_train_adv_examples = self.training_args.num_train_adv_examples

        # Use different AttackArgs based on the value of num_train_adv_examples.
        # If num_train_adv_examples >= 0, it is interpreted as the number of
        # successful adversarial examples to generate.
        # If num_train_adv_examples == -1, num_examples is set to -1 to
        # attack every example in the training data.
        if num_train_adv_examples >= 0:
            attack_args = AttackArgs(
                num_successful_examples=num_train_adv_examples,
                num_examples_offset=0,
                query_budget=self.training_args.query_budget_train,
                shuffle=True,
                parallel=self.training_args.parallel,
                num_workers_per_device=self.training_args.attack_num_workers_per_device,
                disable_stdout=True,
                silent=True,
                log_to_txt=log_file_name + ".txt",
                log_to_csv=log_file_name + ".csv",
            )
        elif num_train_adv_examples == -1:
            # Set num_examples when num_train_adv_examples == -1.
            attack_args = AttackArgs(
                num_examples=num_train_adv_examples,
                num_examples_offset=0,
                query_budget=self.training_args.query_budget_train,
                shuffle=True,
                parallel=self.training_args.parallel,
                num_workers_per_device=self.training_args.attack_num_workers_per_device,
                disable_stdout=True,
                silent=True,
                log_to_txt=log_file_name + ".txt",
                log_to_csv=log_file_name + ".csv",
            )
        else:
            assert False, "num_train_adv_examples is negative and not equal to -1."

        attacker = Attacker(self.attack, self.train_dataset, attack_args=attack_args)
        results = attacker.attack_dataset()

        attack_types = collections.Counter(r.__class__.__name__ for r in results)
        total_attacks = (
            attack_types["SuccessfulAttackResult"] + attack_types["FailedAttackResult"]
        )
        success_rate = attack_types["SuccessfulAttackResult"] / total_attacks * 100
        logger.info(f"Total number of attack results: {len(results)}")
        logger.info(
            f"Attack success rate: {success_rate:.2f}% [{attack_types['SuccessfulAttackResult']} / {total_attacks}]"
        )
        # TODO: This will produce a bug if we need to manipulate ground truth output.

        # To fix Issue #498, we need to group the non-output columns into one tuple
        # representing the input columns. Since "adversarial_example" is not an
        # input to the model, `collate_fn` will remove it from the input dictionary.
        adversarial_examples = [
            (
                tuple(r.perturbed_result.attacked_text._text_input.values())
                + ("adversarial_example",),
                r.perturbed_result.ground_truth_output,
            )
            for r in results
            if isinstance(r, (SuccessfulAttackResult, MaximizedAttackResult))
        ]

        # Name for column indicating if an example is adversarial is set as "_example_type".
        adversarial_dataset = textattack.datasets.Dataset(
            adversarial_examples,
            input_columns=self.train_dataset.input_columns + ("_example_type",),
            label_map=self.train_dataset.label_map,
            label_names=self.train_dataset.label_names,
            output_scale_factor=self.train_dataset.output_scale_factor,
            shuffle=False,
        )
        return adversarial_dataset
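    # For a single-input dataset such as IMDB, each row of the returned
    # adversarial dataset has the shape below (the text and label are
    # hypothetical):
    #
    #     (("this movie was terrib1e", "adversarial_example"), 0)
    #
    # i.e. the perturbed text plus the "_example_type" marker column, followed
    # by the ground-truth output. `collate_fn` in `get_train_dataloader` pops
    # the marker off again before the text reaches the model.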
    def _print_training_args(
        self, total_training_steps, train_batch_size, num_clean_epochs
    ):
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {len(self.train_dataset)}")
        logger.info(f"  Num epochs = {self.training_args.num_epochs}")
        logger.info(f"  Num clean epochs = {num_clean_epochs}")
        logger.info(
            f"  Instantaneous batch size per device = {self.training_args.per_device_train_batch_size}"
        )
        logger.info(
            f"  Total train batch size (w. parallel, distributed & accumulation) = {train_batch_size * self.training_args.gradient_accumulation_steps}"
        )
        logger.info(
            f"  Gradient accumulation steps = {self.training_args.gradient_accumulation_steps}"
        )
        logger.info(f"  Total optimization steps = {total_training_steps}")

    def _save_model_checkpoint(
        self, model, tokenizer, step=None, epoch=None, best=False, last=False
    ):
        # Save model checkpoint
        if step:
            dir_name = f"checkpoint-step-{step}"
        if epoch:
            dir_name = f"checkpoint-epoch-{epoch}"
        if best:
            dir_name = "best_model"
        if last:
            dir_name = "last_model"

        output_dir = os.path.join(self.training_args.output_dir, dir_name)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if isinstance(model, (WordCNNForClassification, LSTMForClassification)):
            model.save_pretrained(output_dir)
        elif isinstance(model, transformers.PreTrainedModel):
            model.save_pretrained(output_dir)
            tokenizer.save_pretrained(output_dir)
        else:
            state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
            torch.save(
                state_dict,
                os.path.join(output_dir, "pytorch_model.bin"),
            )
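    # Depending on which checkpoints are requested, `training_args.output_dir`
    # ends up looking roughly like this (`train_log.txt`, `training_args.json`,
    # and `README.md` are written by `train()` and `_write_readme()`):
    #
    #     output_dir/
    #         checkpoint-step-1000/
    #         checkpoint-epoch-1/
    #         best_model/
    #         last_model/
    #         train_log.txt
    #         training_args.json
    #         README.md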
    def _tb_log(self, log, step):
        if not hasattr(self, "_tb_writer"):
            from torch.utils.tensorboard import SummaryWriter

            self._tb_writer = SummaryWriter(self.training_args.tb_log_dir)
            self._tb_writer.add_hparams(self.training_args.__dict__, {})
            self._tb_writer.flush()

        for key in log:
            self._tb_writer.add_scalar(key, log[key], step)

    def _wandb_log(self, log, step):
        if not hasattr(self, "_wandb_init"):
            global wandb
            import wandb

            self._wandb_init = True
            wandb.init(
                project=self.training_args.wandb_project,
                config=self.training_args.__dict__,
            )

        wandb.log(log, step=step)
    def get_optimizer_and_scheduler(self, model, num_training_steps):
        """Returns optimizer and scheduler to use for training. If you are
        overriding this method and do not want to use a scheduler, simply
        return :obj:`None` for the scheduler.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to be trained. Pass its parameters to the optimizer for training.
            num_training_steps (:obj:`int`):
                Number of total training steps.
        Returns:
            Tuple of optimizer and scheduler :obj:`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler._LRScheduler]`
        """
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if isinstance(model, transformers.PreTrainedModel):
            # Reference https://huggingface.co/transformers/training.html
            param_optimizer = list(model.named_parameters())
            no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in param_optimizer
                        if not any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": self.training_args.weight_decay,
                },
                {
                    "params": [
                        p for n, p in param_optimizer if any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer = transformers.optimization.AdamW(
                optimizer_grouped_parameters, lr=self.training_args.learning_rate
            )

            if isinstance(self.training_args.num_warmup_steps, float):
                num_warmup_steps = math.ceil(
                    self.training_args.num_warmup_steps * num_training_steps
                )
            else:
                num_warmup_steps = self.training_args.num_warmup_steps

            scheduler = transformers.optimization.get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=num_warmup_steps,
                num_training_steps=num_training_steps,
            )
        else:
            optimizer = torch.optim.Adam(
                filter(lambda x: x.requires_grad, model.parameters()),
                lr=self.training_args.learning_rate,
            )
            scheduler = None

        return optimizer, scheduler
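    # A minimal sketch of overriding this hook, e.g. to train without a
    # scheduler (the subclass name and the choice of `torch.optim.AdamW` are
    # illustrative assumptions, not part of this module):
    #
    #     class MyTrainer(Trainer):
    #         def get_optimizer_and_scheduler(self, model, num_training_steps):
    #             optimizer = torch.optim.AdamW(
    #                 model.parameters(), lr=self.training_args.learning_rate
    #             )
    #             return optimizer, None  # `None` scheduler skips scheduler.step()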
    def get_train_dataloader(self, dataset, adv_dataset, batch_size):
        """Returns the :obj:`torch.utils.data.DataLoader` for training.

        Args:
            dataset (:class:`~textattack.datasets.Dataset`):
                Original training dataset.
            adv_dataset (:class:`~textattack.datasets.Dataset`):
                Adversarial examples generated from the original training dataset. :obj:`None` if no adversarial attack takes place.
            batch_size (:obj:`int`):
                Batch size for training.
        Returns:
            :obj:`torch.utils.data.DataLoader`
        """
        # TODO: Add pairing option where we can pair original examples with adversarial examples.
        # Helper function for collating data.
        def collate_fn(data):
            input_texts = []
            targets = []
            is_adv_sample = []
            for item in data:
                if "_example_type" in item[0].keys():
                    # Get example type value from OrderedDict and remove it.
                    adv = item[0].pop("_example_type")

                    # With "_example_type" removed from the item[0] OrderedDict,
                    # all other keys should be part of the input.
                    _input, label = item
                    if adv != "adversarial_example":
                        raise ValueError(
                            "`item` has length of 3 but the last element does not mark whether the item is an adversarial example."
                        )
                    else:
                        is_adv_sample.append(True)
                else:
                    # Else `len(item)` is 2.
                    _input, label = item
                    is_adv_sample.append(False)

                if isinstance(_input, collections.OrderedDict):
                    _input = tuple(_input.values())
                else:
                    _input = tuple(_input)

                if len(_input) == 1:
                    _input = _input[0]
                input_texts.append(_input)
                targets.append(label)

            return input_texts, torch.tensor(targets), torch.tensor(is_adv_sample)

        if adv_dataset:
            dataset = torch.utils.data.ConcatDataset([dataset, adv_dataset])

        train_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return train_dataloader
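    # With a batch size of 2, `collate_fn` above produces batches of the form
    # (texts and labels are hypothetical):
    #
    #     (["a clean review", "a perturbed rev1ew"],
    #      tensor([1, 0]),            # targets
    #      tensor([False, True]))     # is_adv_sample mask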
    def get_eval_dataloader(self, dataset, batch_size):
        """Returns the :obj:`torch.utils.data.DataLoader` for evaluation.

        Args:
            dataset (:class:`~textattack.datasets.Dataset`):
                Dataset to use for evaluation.
            batch_size (:obj:`int`):
                Batch size for evaluation.
        Returns:
            :obj:`torch.utils.data.DataLoader`
        """
        # Helper function for collating data.
        def collate_fn(data):
            input_texts = []
            targets = []
            for _input, label in data:
                if isinstance(_input, collections.OrderedDict):
                    _input = tuple(_input.values())
                else:
                    _input = tuple(_input)

                if len(_input) == 1:
                    _input = _input[0]
                input_texts.append(_input)
                targets.append(label)
            return input_texts, torch.tensor(targets)

        eval_dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            pin_memory=True,
        )
        return eval_dataloader
    def training_step(self, model, tokenizer, batch):
        """Perform a single training step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to train.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor, torch.Tensor]`):
                By default, this will be a tuple of input texts, targets, and a boolean tensor indicating whether each sample is an adversarial example.

            .. note::
                If you override the :meth:`get_train_dataloader` method, then the shape/type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor, torch.Tensor]` where

            - **loss**: :obj:`torch.FloatTensor` of shape 1 containing the loss.
            - **preds**: :obj:`torch.FloatTensor` of the model's prediction for the batch.
            - **targets**: :obj:`torch.Tensor` of the model's targets (e.g. labels, target values).
        """
        input_texts, targets, is_adv_sample = batch
        _targets = targets
        targets = targets.to(textattack.shared.utils.device)

        if isinstance(model, transformers.PreTrainedModel) or (
            isinstance(model, torch.nn.DataParallel)
            and isinstance(model.module, transformers.PreTrainedModel)
        ):
            input_ids = tokenizer(
                input_texts,
                padding="max_length",
                return_tensors="pt",
                truncation=True,
            )
            input_ids.to(textattack.shared.utils.device)
            logits = model(**input_ids)[0]
        else:
            input_ids = tokenizer(input_texts)
            if not isinstance(input_ids, torch.Tensor):
                input_ids = torch.tensor(input_ids)
            input_ids = input_ids.to(textattack.shared.utils.device)
            logits = model(input_ids)

        if self.task_type == "regression":
            loss = self.loss_fct(logits.squeeze(), targets.squeeze())
            preds = logits
        else:
            loss = self.loss_fct(logits, targets)
            preds = logits.argmax(dim=-1)

        sample_weights = torch.ones(
            is_adv_sample.size(), device=textattack.shared.utils.device
        )
        sample_weights[is_adv_sample] *= self.training_args.alpha
        loss = loss * sample_weights
        loss = torch.mean(loss)
        preds = preds.cpu()

        return loss, preds, _targets
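    # A worked example of the loss weighting above, assuming `alpha` is 2.0 and
    # the second of two samples is adversarial (all values hypothetical):
    #
    #     per-sample losses:  [0.50, 0.20]
    #     sample weights:     [1.00, 2.00]
    #     weighted mean loss: (0.50 * 1.0 + 0.20 * 2.0) / 2 = 0.45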
    def evaluate_step(self, model, tokenizer, batch):
        """Perform a single evaluation step on a batch of inputs.

        Args:
            model (:obj:`torch.nn.Module`):
                Model to evaluate.
            tokenizer:
                Tokenizer used to tokenize input text.
            batch (:obj:`tuple[list[str], torch.Tensor]`):
                By default, this will be a tuple of input texts and target tensors.

            .. note::
                If you override the :meth:`get_eval_dataloader` method, then the shape/type of :obj:`batch` will depend on how you created your batch.

        Returns:
            :obj:`tuple[torch.Tensor, torch.Tensor]` where

            - **preds**: :obj:`torch.FloatTensor` of the model's prediction for the batch.
            - **targets**: :obj:`torch.Tensor` of the model's targets (e.g. labels, target values).
        """
        input_texts, targets = batch
        _targets = targets
        targets = targets.to(textattack.shared.utils.device)

        if isinstance(model, transformers.PreTrainedModel):
            input_ids = tokenizer(
                input_texts,
                padding="max_length",
                return_tensors="pt",
                truncation=True,
            )
            input_ids.to(textattack.shared.utils.device)
            logits = model(**input_ids)[0]
        else:
            input_ids = tokenizer(input_texts)
            if not isinstance(input_ids, torch.Tensor):
                input_ids = torch.tensor(input_ids)
            input_ids = input_ids.to(textattack.shared.utils.device)
            logits = model(input_ids)

        if self.task_type == "regression":
            preds = logits
        else:
            preds = logits.argmax(dim=-1)

        return preds.cpu(), _targets
    def train(self):
        """Train the model on the given training dataset."""
        if not self.train_dataset:
            raise ValueError("No `train_dataset` available for training.")

        textattack.shared.utils.set_seed(self.training_args.random_seed)
        if not os.path.exists(self.training_args.output_dir):
            os.makedirs(self.training_args.output_dir)

        # Save logger writes to file
        log_txt_path = os.path.join(self.training_args.output_dir, "train_log.txt")
        fh = logging.FileHandler(log_txt_path)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)
        logger.info(f"Writing logs to {log_txt_path}.")

        # Save original self.training_args to file
        args_save_path = os.path.join(
            self.training_args.output_dir, "training_args.json"
        )
        with open(args_save_path, "w", encoding="utf-8") as f:
            json.dump(self.training_args.__dict__, f)
        logger.info(f"Wrote original training args to {args_save_path}.")

        num_gpus = torch.cuda.device_count()
        tokenizer = self.model_wrapper.tokenizer
        model = self.model_wrapper.model

        if self.training_args.parallel and num_gpus > 1:
            # TODO: torch.nn.parallel.DistributedDataParallel
            # Supposedly faster than DataParallel, but requires more work to set up properly.
            model = torch.nn.DataParallel(model)
            logger.info(f"Training on {num_gpus} GPUs via `torch.nn.DataParallel`.")
            train_batch_size = self.training_args.per_device_train_batch_size * num_gpus
        else:
            train_batch_size = self.training_args.per_device_train_batch_size

        if self.attack is None:
            num_clean_epochs = self.training_args.num_epochs
        else:
            num_clean_epochs = self.training_args.num_clean_epochs

        total_clean_training_steps = (
            math.ceil(
                len(self.train_dataset)
                / (train_batch_size * self.training_args.gradient_accumulation_steps)
            )
            * num_clean_epochs
        )

        # Calculate total_adv_training_data_length based on the type of
        # num_train_adv_examples.
        # If num_train_adv_examples is a float, it is the fraction of train_dataset to attack.
        if isinstance(self.training_args.num_train_adv_examples, float):
            total_adv_training_data_length = (
                len(self.train_dataset) * self.training_args.num_train_adv_examples
            )
        # If num_train_adv_examples is an int >= 0, it is taken as the value itself.
        elif (
            isinstance(self.training_args.num_train_adv_examples, int)
            and self.training_args.num_train_adv_examples >= 0
        ):
            total_adv_training_data_length = self.training_args.num_train_adv_examples
        # If num_train_adv_examples == -1, we generate all possible adv examples.
        # The max number of possible adv examples equals the size of train_dataset.
        else:
            total_adv_training_data_length = len(self.train_dataset)

        # From total_adv_training_data_length, find total_adv_training_steps.
        total_adv_training_steps = math.ceil(
            (len(self.train_dataset) + total_adv_training_data_length)
            / (train_batch_size * self.training_args.gradient_accumulation_steps)
        ) * (self.training_args.num_epochs - num_clean_epochs)

        total_training_steps = total_clean_training_steps + total_adv_training_steps

        optimizer, scheduler = self.get_optimizer_and_scheduler(
            model, total_training_steps
        )
        self._print_training_args(
            total_training_steps, train_batch_size, num_clean_epochs
        )

        model.to(textattack.shared.utils.device)

        # Variables across epochs
        self._total_loss = 0.0
        self._current_loss = 0.0
        self._last_log_step = 0

        # `best_eval_score` is used to keep track of the best model across training.
        # Could be loss, accuracy, or other metrics.
        best_eval_score = 0.0
        best_eval_score_epoch = 0
        best_model_path = None
        epochs_since_best_eval_score = 0
        for epoch in range(1, self.training_args.num_epochs + 1):
            logger.info("==========================================================")
            logger.info(f"Epoch {epoch}")

            if self.attack and epoch > num_clean_epochs:
                if (
                    epoch - num_clean_epochs - 1
                ) % self.training_args.attack_epoch_interval == 0:
                    # Only generate a new adversarial training set every
                    # `attack_epoch_interval` epochs after the clean epochs.
                    # adv_dataset is an instance of `textattack.datasets.Dataset`.
                    model.eval()
                    adv_dataset = self._generate_adversarial_examples(epoch)
                    model.train()
                    model.to(textattack.shared.utils.device)
                else:
                    adv_dataset = None
            else:
                logger.info(f"Running clean epoch {epoch}/{num_clean_epochs}")
                adv_dataset = None

            train_dataloader = self.get_train_dataloader(
                self.train_dataset, adv_dataset, train_batch_size
            )
            model.train()
            # Epoch variables
            all_preds = []
            all_targets = []

            prog_bar = tqdm.tqdm(
                train_dataloader,
                desc="Iteration",
                position=0,
                leave=True,
                dynamic_ncols=True,
            )
            for step, batch in enumerate(prog_bar):
                loss, preds, targets = self.training_step(model, tokenizer, batch)

                if isinstance(model, torch.nn.DataParallel):
                    loss = loss.mean()

                loss = loss / self.training_args.gradient_accumulation_steps
                loss.backward()
                loss = loss.item()
                self._total_loss += loss
                self._current_loss += loss

                all_preds.append(preds)
                all_targets.append(targets)

                if (step + 1) % self.training_args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    if scheduler:
                        scheduler.step()
                    optimizer.zero_grad()
                    self._global_step += 1

                if self._global_step > 0:
                    prog_bar.set_description(
                        f"Loss {self._total_loss/self._global_step:.5f}"
                    )

                # TODO: Better way to handle TB and Wandb logging
                if (self._global_step > 0) and (
                    self._global_step % self.training_args.logging_interval_step == 0
                ):
                    lr_to_log = (
                        scheduler.get_last_lr()[0]
                        if scheduler
                        else self.training_args.learning_rate
                    )
                    if self._global_step - self._last_log_step >= 1:
                        loss_to_log = round(
                            self._current_loss
                            / (self._global_step - self._last_log_step),
                            4,
                        )
                    else:
                        loss_to_log = round(self._current_loss, 4)

                    log = {"train/loss": loss_to_log, "train/learning_rate": lr_to_log}
                    if self.training_args.log_to_tb:
                        self._tb_log(log, self._global_step)
                    if self.training_args.log_to_wandb:
                        self._wandb_log(log, self._global_step)
                    self._current_loss = 0.0
                    self._last_log_step = self._global_step
                # Save model checkpoint to file.
                if self.training_args.checkpoint_interval_steps:
                    if (
                        self._global_step > 0
                        and (
                            self._global_step
                            % self.training_args.checkpoint_interval_steps
                        )
                        == 0
                    ):
                        self._save_model_checkpoint(
                            model, tokenizer, step=self._global_step
                        )

            preds = torch.cat(all_preds)
            targets = torch.cat(all_targets)
            if self._metric_name == "accuracy":
                correct_predictions = (preds == targets).sum().item()
                accuracy = correct_predictions / len(targets)
                metric_log = {"train/train_accuracy": accuracy}
                logger.info(f"Train accuracy: {accuracy*100:.2f}%")
            else:
                pearson_correlation, pearson_pvalue = scipy.stats.pearsonr(
                    preds, targets
                )
                metric_log = {
                    "train/pearson_correlation": pearson_correlation,
                    "train/pearson_pvalue": pearson_pvalue,
                }
                logger.info(f"Train Pearson correlation: {pearson_correlation:.4f}")

            if len(targets) > 0:
                if self.training_args.log_to_tb:
                    self._tb_log(metric_log, epoch)
                if self.training_args.log_to_wandb:
                    metric_log["epoch"] = epoch
                    self._wandb_log(metric_log, self._global_step)
            # Evaluate after each epoch.
            eval_score = self.evaluate()

            if self.training_args.log_to_tb:
                self._tb_log({f"eval/{self._metric_name}": eval_score}, epoch)
            if self.training_args.log_to_wandb:
                self._wandb_log(
                    {f"eval/{self._metric_name}": eval_score, "epoch": epoch},
                    self._global_step,
                )

            if (
                self.training_args.checkpoint_interval_epochs
                and (epoch % self.training_args.checkpoint_interval_epochs) == 0
            ):
                self._save_model_checkpoint(model, tokenizer, epoch=epoch)

            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_eval_score_epoch = epoch
                epochs_since_best_eval_score = 0
                self._save_model_checkpoint(model, tokenizer, best=True)
                logger.info(
                    f"Best score found. Saved model to {self.training_args.output_dir}/best_model/"
                )
            else:
                epochs_since_best_eval_score += 1
                if self.training_args.early_stopping_epochs and (
                    epochs_since_best_eval_score
                    > self.training_args.early_stopping_epochs
                ):
                    logger.info(
                        f"Stopping early since it's been {self.training_args.early_stopping_epochs} epochs since the validation score last increased."
                    )
                    break

        if self.training_args.log_to_tb:
            self._tb_writer.flush()

        # Finish training
        if isinstance(model, torch.nn.DataParallel):
            model = model.module

        if self.training_args.load_best_model_at_end:
            best_model_path = os.path.join(self.training_args.output_dir, "best_model")
            if hasattr(model, "from_pretrained"):
                model = model.__class__.from_pretrained(best_model_path)
            else:
                # `load_state_dict` mutates `model` in place and returns only
                # key-matching info, so don't rebind `model` to its return value.
                model.load_state_dict(
                    torch.load(os.path.join(best_model_path, "pytorch_model.bin"))
                )

        if self.training_args.save_last:
            self._save_model_checkpoint(model, tokenizer, last=True)

        self.model_wrapper.model = model
        self._write_readme(best_eval_score, best_eval_score_epoch, train_batch_size)
    def evaluate(self):
        """Evaluate the model on the given evaluation dataset."""
        if not self.eval_dataset:
            raise ValueError("No `eval_dataset` available for evaluation.")

        logging.info("Evaluating model on evaluation dataset.")
        model = self.model_wrapper.model
        tokenizer = self.model_wrapper.tokenizer

        model.eval()
        all_preds = []
        all_targets = []

        if isinstance(model, torch.nn.DataParallel):
            num_gpus = torch.cuda.device_count()
            eval_batch_size = self.training_args.per_device_eval_batch_size * num_gpus
        else:
            eval_batch_size = self.training_args.per_device_eval_batch_size
        eval_dataloader = self.get_eval_dataloader(self.eval_dataset, eval_batch_size)

        with torch.no_grad():
            for step, batch in enumerate(eval_dataloader):
                preds, targets = self.evaluate_step(model, tokenizer, batch)
                all_preds.append(preds)
                all_targets.append(targets)

        preds = torch.cat(all_preds)
        targets = torch.cat(all_targets)

        if self.task_type == "regression":
            pearson_correlation, pearson_p_value = scipy.stats.pearsonr(preds, targets)
            eval_score = pearson_correlation
        else:
            correct_predictions = (preds == targets).sum().item()
            accuracy = correct_predictions / len(targets)
            eval_score = accuracy

        if self._metric_name == "accuracy":
            logger.info(f"Eval {self._metric_name}: {eval_score*100:.2f}%")
        else:
            logger.info(f"Eval {self._metric_name}: {eval_score:.4f}")

        return eval_score
    def _write_readme(self, best_eval_score, best_eval_score_epoch, train_batch_size):
        if isinstance(self.training_args, CommandLineTrainingArgs):
            model_name = self.training_args.model_name_or_path
        elif isinstance(self.model_wrapper.model, transformers.PreTrainedModel):
            if (
                hasattr(self.model_wrapper.model.config, "_name_or_path")
                and self.model_wrapper.model.config._name_or_path in HUGGINGFACE_MODELS
            ):
                # TODO: Better way than just checking HUGGINGFACE_MODELS?
                model_name = self.model_wrapper.model.config._name_or_path
            elif hasattr(self.model_wrapper.model.config, "model_type"):
                model_name = self.model_wrapper.model.config.model_type
            else:
                model_name = ""
        else:
            model_name = ""

        if model_name:
            model_name = f"`{model_name}`"

        if (
            isinstance(self.training_args, CommandLineTrainingArgs)
            and self.training_args.model_max_length
        ):
            model_max_length = self.training_args.model_max_length
        elif isinstance(
            self.model_wrapper.model,
            (
                transformers.PreTrainedModel,
                LSTMForClassification,
                WordCNNForClassification,
            ),
        ):
            model_max_length = self.model_wrapper.tokenizer.model_max_length
        else:
            model_max_length = None

        if model_max_length:
            model_max_length_str = f" a maximum sequence length of {model_max_length},"
        else:
            model_max_length_str = ""

        if isinstance(
            self.train_dataset, textattack.datasets.HuggingFaceDataset
        ) and hasattr(self.train_dataset, "_name"):
            dataset_name = self.train_dataset._name
            if hasattr(self.train_dataset, "_subset"):
                dataset_name += f" ({self.train_dataset._subset})"
        elif isinstance(
            self.eval_dataset, textattack.datasets.HuggingFaceDataset
        ) and hasattr(self.eval_dataset, "_name"):
            dataset_name = self.eval_dataset._name
            if hasattr(self.eval_dataset, "_subset"):
                dataset_name += f" ({self.eval_dataset._subset})"
        else:
            dataset_name = None

        if dataset_name:
            # Needs the f-prefix (and a leading space) so `dataset_name` is
            # actually interpolated into the README text below.
            dataset_str = (
                f" and the `{dataset_name}` dataset loaded using the `datasets` library"
            )
        else:
            dataset_str = ""

        loss_func = (
            "mean squared error" if self.task_type == "regression" else "cross-entropy"
        )
        metric_name = (
            "pearson correlation" if self.task_type == "regression" else "accuracy"
        )
        epoch_info = f"{best_eval_score_epoch} epoch" + (
            "s" if best_eval_score_epoch > 1 else ""
        )
        readme_text = f"""
## TextAttack Model Card

This {model_name} model was fine-tuned using TextAttack{dataset_str}. The model was fine-tuned
for {self.training_args.num_epochs} epochs with a batch size of {train_batch_size},
{model_max_length_str} and an initial learning rate of {self.training_args.learning_rate}.
Since this was a {self.task_type} task, the model was trained with a {loss_func} loss function.
The best score the model achieved on this task was {best_eval_score}, as measured by the
eval set {metric_name}, found after {epoch_info}.

For more information, check out [TextAttack on Github](https://github.com/QData/TextAttack).
"""
        readme_save_path = os.path.join(self.training_args.output_dir, "README.md")
        with open(readme_save_path, "w", encoding="utf-8") as f:
            f.write(readme_text.strip() + "\n")
        logger.info(f"Wrote README to {readme_save_path}.")