#!/usr/bin/env python
# coding=utf-8
# Copyright The HuggingFace Team and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
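
# Training script for an EnClap-style BART audio-captioning model: it feeds
# EnCodec token and CLAP embedding features through the Preprocessor and
# DataCollatorForEnClapBart pipeline, fine-tunes EnClapBartForConditionalGeneration
# with HuggingFace Accelerate, and reports METEOR/SPIDEr (via aac-metrics) on the
# validation set.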
import logging
import math
import os
import sys

import datasets
import numpy as np
import torch
import transformers
from aac_metrics import evaluate
from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from omegaconf import OmegaConf
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoTokenizer,
    BartConfig,
    get_inverse_sqrt_schedule,
    get_scheduler,
)

from data.collator import DataCollatorForEnClapBart
from data.preprocess import Preprocessor
from modeling.enclap_bart import EnClapBartForConditionalGeneration

logger = get_logger(__name__)
metric_list = ["meteor", "spider"]


def main():
    # Load Configuration
    cfg_path = sys.argv[1]
    args = OmegaConf.load(cfg_path)
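    # NOTE: the YAML config must provide every field accessed on `args` below.
    # A minimal sketch (field names come from this script; the values are
    # purely illustrative):
    #
    #   output_dir: ./outputs/enclap
    #   train_file: data/train.csv
    #   validation_file: data/valid.csv
    #   tokenizer_name: facebook/bart-base
    #   model_name_or_path: facebook/bart-base
    #   per_device_train_batch_size: 32
    #   per_device_eval_batch_size: 32
    #   learning_rate: 2.0e-5
    #   num_train_epochs: 10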
    # Initialize Logging
    accelerator_log_kwargs = {}
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    if args.with_tracking:
        accelerator_log_kwargs["log_with"] = args.report_to
        accelerator_log_kwargs["project_dir"] = args.output_dir

    # Initialize Accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        split_batches=args.split_batches,
        kwargs_handlers=[ddp_kwargs],
        **accelerator_log_kwargs,
    )

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            with open(os.path.join(args.output_dir, "args.yaml"), "w") as f:
                OmegaConf.save(args, f)
    accelerator.wait_for_everyone()

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    file_handler = logging.FileHandler(os.path.join(args.output_dir, "train_log.txt"))
    logger.logger.addHandler(file_handler)
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Get the datasets
    data_files = {}
    data_files_eval = {}
    if args.train_file is not None:
        data_files["train"] = args.train_file
    if args.validation_file is not None:
        data_files_eval["validation"] = args.validation_file
    extension = args.train_file.split(".")[-1]
    raw_datasets = load_dataset(extension, data_files=data_files)
    raw_datasets_eval = load_dataset(extension, data_files=data_files_eval)

    # Load pretrained model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    if args.config_name_or_path is not None:
        config = BartConfig.from_pretrained(args.config_name_or_path)
    else:
        config = None

    if args.model_name_or_path is not None:
        if config is None:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path
            )
        else:
            model = EnClapBartForConditionalGeneration.from_pretrained(
                args.model_name_or_path, config=config
            )
    else:
        model = EnClapBartForConditionalGeneration(config=config)

    # Set the generation config
    if args.val_max_target_length is None:
        args.val_max_target_length = args.max_target_length

    # Set max encodec length based on the shape of the positional encoding
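    # (assumption: the -2 reserves room for the BOS/EOS special tokens added
    # around the EnCodec token sequence)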
    max_encodec_length = model.config.max_position_embeddings - 2
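    # Label positions equal to -100 are ignored by PyTorch's cross-entropy loss,
    # so padded target positions do not contribute to the training loss.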
    label_pad_token_id = (
        -100 if args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )
    preprocessor = Preprocessor(
        args.encodec_base_path,
        args.clap_base_path,
        tokenizer,
        model.config.max_position_embeddings,
        args.encodec_masking_prob,
        args.encodec_masking_span,
        label_pad_token_id,
        model.config.encodec_vocab_size,
        args.eval_num_captions,
    )

    with accelerator.main_process_first():
        train_dataset = raw_datasets["train"].map(
            preprocessor.preprocess_train,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        train_dataset.set_format(
            "pt",
            columns=[
                "input_ids",
                "attention_mask",
                "clap",
                "labels",
                "decoder_attention_mask",
            ],
        )
        # Temporarily set max_target_length for validation.
        eval_dataset = raw_datasets_eval["validation"].map(
            preprocessor.preprocess_eval,
            num_proc=args.preprocessing_num_workers,
            load_from_cache_file=not args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
        eval_dataset.set_format(
            "pt",
            columns=["input_ids", "attention_mask", "clap"],
            output_all_columns=True,
        )

    train_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
        encodec_masking_prob=args.encodec_masking_prob,
        encodec_masking_span=args.encodec_masking_span,
    )
    valid_data_collator = DataCollatorForEnClapBart(
        tokenizer=tokenizer,
        model=model,
        return_tensors="pt",
        label_pad_token_id=label_pad_token_id,
        max_length=max_encodec_length,
    )
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=train_data_collator,
        batch_size=args.per_device_train_batch_size,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        collate_fn=valid_data_collator,
        batch_size=args.per_device_eval_batch_size,
    )

    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [
                p
                for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    if args.lr_scheduler_type == "inverse_sqrt" and hasattr(args, "time_scale"):
        lr_scheduler = get_inverse_sqrt_schedule(
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps,
            timescale=args.time_scale,
        )
    else:
        lr_scheduler = get_scheduler(
            name=args.lr_scheduler_type,
            optimizer=optimizer,
            num_warmup_steps=args.num_warmup_steps,
            num_training_steps=args.max_train_steps,
        )

    # Prepare everything with our `accelerator`.
    (
        model,
        optimizer,
        train_dataloader,
        eval_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
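    # (args.checkpointing_steps may be an integer-like string such as "500", or the
    # literal "epoch"; see the end of the epoch loop below)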
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # The trackers initialize automatically on the main process.
    if args.with_tracking:
        accelerator.init_trackers(args.logging_dir)

    # Train!
    total_batch_size = (
        args.per_device_train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )
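    # With split_batches=True, Accelerate splits each batch across processes instead
    # of giving every process its own full batch, so the total batch size should not
    # be multiplied by the number of processes.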
    if args.split_batches:
        total_batch_size = int(total_batch_size / accelerator.num_processes)

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(
        f"  Instantaneous batch size per device = {args.per_device_train_batch_size}"
    )
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    completed_steps = 0
    starting_epoch = 0

    # Potentially load in the weights and states from a previous save
    if not args.overwrite_output_dir and os.path.exists(
        os.path.join(args.output_dir, "checkpoints")
    ):
        if args.resume_from_checkpoint is not None:
            accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
            accelerator.load_state(args.resume_from_checkpoint)
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint; sorting by creation time puts the
            # latest checkpoint last.
            dirs = [
                f
                for f in os.scandir(os.path.join(args.output_dir, "checkpoints"))
                if f.is_dir()
            ]
            dirs.sort(key=os.path.getctime)
            path = dirs[-1].name
            accelerator.print(f"Resumed from checkpoint: {dirs[-1]}")
            accelerator.load_state(dirs[-1])
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
            completed_steps = starting_epoch * num_update_steps_per_epoch
        else:
            # need to multiply by `gradient_accumulation_steps` to reflect real steps
            resume_step = (
                int(training_difference.replace("step_", ""))
                * args.gradient_accumulation_steps
            )
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)
            completed_steps = resume_step // args.gradient_accumulation_steps
    # Running loss accumulators used for tracking/logging
    if args.with_tracking:
        total_loss = 0
        logging_loss = 0
        before_epoch_loss = 0
        if args.encodec_masking_prob > 0:
            total_encodec_loss = 0
            logging_encodec_loss = 0
            before_epoch_encodec_loss = 0

    for epoch in range(starting_epoch, args.num_train_epochs):
        model.train()
        if (
            args.resume_from_checkpoint
            and epoch == starting_epoch
            and resume_step is not None
        ):
            # We skip the first `n` batches in the dataloader when resuming from a checkpoint
            active_dataloader = accelerator.skip_first_batches(
                train_dataloader, resume_step
            )
        else:
            active_dataloader = train_dataloader
        logger.info(f"***** Running epoch {epoch} *****")
        epoch_iterator = tqdm(
            active_dataloader,
            desc="Training",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="CYAN",
        )
        for step, batch in enumerate(epoch_iterator):
            with accelerator.accumulate(model):
                outputs = model(**batch)
                loss = outputs.loss
                # We keep track of the loss at each epoch
                if args.with_tracking:
                    total_loss += outputs.lm_loss.item()
                    if args.encodec_masking_prob > 0:
                        if outputs.encodec_loss is not None:
                            total_encodec_loss += outputs.encodec_loss.item()
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        model.parameters(), max_norm=args.max_grad_norm
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                completed_steps += 1
                # Add loss information to tqdm
                epoch_iterator.set_postfix(loss=total_loss / completed_steps)
                if completed_steps % args.logging_steps == 0:
                    train_log = {
                        "train/learning_rate": lr_scheduler.get_last_lr()[0]
                    }
                    train_log["train/loss"] = (
                        total_loss - logging_loss
                    ) / args.logging_steps
                    logging_loss = total_loss
                    if args.encodec_masking_prob > 0:
                        train_log["train/encodec_loss"] = (
                            total_encodec_loss - logging_encodec_loss
                        ) / args.logging_steps
                        logging_encodec_loss = total_encodec_loss
                    accelerator.log(train_log, step=completed_steps)

                if isinstance(checkpointing_steps, int):
                    if completed_steps % checkpointing_steps == 0:
                        output_dir = f"step_{completed_steps}"
                        if args.output_dir is not None:
                            output_dir = os.path.join(
                                args.output_dir, "checkpoints", output_dir
                            )
                        accelerator.save_state(output_dir)

            if completed_steps >= args.max_train_steps:
                break
        model.eval()
        gen_kwargs = {
            "max_length": args.val_max_target_length,
        }
        predictions = []
        references = []
        eval_iterator = tqdm(
            eval_dataloader,
            desc="Validation",
            disable=not accelerator.is_local_main_process,
            dynamic_ncols=True,
            colour="MAGENTA",
        )
        for step, batch in enumerate(eval_iterator):
            # Drop the padded samples of the last batch of dataloader
            # try:
            #     if accelerator.gradient_state.end_of_dataloader and accelerator.gradient_state.remainder > 0:
            #         batch = batch[:accelerator.gradient_state.remainder]
            # except:
            #     pass
            with torch.no_grad():
                batch["input_ids"] = batch["input_ids"].cuda()
                batch["clap"] = batch["clap"].cuda()
                batch["attention_mask"] = batch["attention_mask"].cuda()
                batch["eos_mask"] = batch["eos_mask"].cuda()
                generated_tokens = accelerator.unwrap_model(model).generate(
                    batch["input_ids"],
                    clap=batch["clap"],
                    attention_mask=batch["attention_mask"],
                    eos_mask=batch["eos_mask"],
                    **gen_kwargs,
                )
                generated_tokens = accelerator.pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
                generated_tokens = generated_tokens.cpu().numpy()
                captions = batch["captions"]
                if isinstance(generated_tokens, tuple):
                    generated_tokens = generated_tokens[0]
                decoded_preds = tokenizer.batch_decode(
                    generated_tokens, skip_special_tokens=True
                )
                predictions.extend(decoded_preds)
                references.extend(captions)

        logger.info("Evaluating predictions...")
        result = evaluate(predictions, references, metrics=metric_list)
        # Gather Result
        result = {k: v.cuda() for k, v in result[0].items()}
        result = accelerator.gather_for_metrics(result)
        # Log the average of the metrics among the processes
        if accelerator.num_processes > 1:
            result = {f"eval/{k}": round(v.mean().item(), 4) for k, v in result.items()}
        else:
            result = {f"eval/{k}": round(v.item(), 4) for k, v in result.items()}
        logger.info(result)
        if args.with_tracking:
            result["train/epoch_train_loss"] = (total_loss - before_epoch_loss) / len(
                train_dataloader
            )
            result["train/steps"] = completed_steps
            before_epoch_loss = total_loss
            if args.encodec_masking_prob > 0:
                result["train/epoch_encodec_loss"] = (
                    total_encodec_loss - before_epoch_encodec_loss
                ) / len(train_dataloader)
                before_epoch_encodec_loss = total_encodec_loss
            accelerator.log(result, step=epoch)

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, "checkpoints", output_dir)
            accelerator.save_state(output_dir)
            if accelerator.is_main_process:
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.config.save_pretrained(output_dir)
    if args.output_dir is not None:
        save_dir = os.path.join(args.output_dir, "final")
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            save_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
        )
        if accelerator.is_main_process:
            tokenizer.save_pretrained(save_dir)


if __name__ == "__main__":
    main()