Spaces:
Running
Running
| import sys | |
| import traceback | |
| from finetrainers import BaseArgs, SFTTrainer, TrainingType, get_logger | |
| from finetrainers.config import _get_model_specifiction_cls | |
| from finetrainers.trainer.sft_trainer.config import SFTFullRankConfig, SFTLowRankConfig | |
# Module-level logger obtained from finetrainers' get_logger(); used by the
# training entry point below to report setup problems and training errors.
logger = get_logger()
| def main(): | |
| try: | |
| import multiprocessing | |
| multiprocessing.set_start_method("fork") | |
| except Exception as e: | |
| logger.error( | |
| f'Failed to set multiprocessing start method to "fork". This can lead to poor performance, high memory usage, or crashes. ' | |
| f"See: https://pytorch.org/docs/stable/notes/multiprocessing.html\n" | |
| f"Error: {e}" | |
| ) | |
| try: | |
| args = BaseArgs() | |
| argv = [y.strip() for x in sys.argv for y in x.split()] | |
| training_type_index = argv.index("--training_type") | |
| if training_type_index == -1: | |
| raise ValueError("Training type not provided in command line arguments.") | |
| training_type = argv[training_type_index + 1] | |
| training_cls = None | |
| if training_type == TrainingType.LORA: | |
| training_cls = SFTLowRankConfig | |
| elif training_type == TrainingType.FULL_FINETUNE: | |
| training_cls = SFTFullRankConfig | |
| else: | |
| raise ValueError(f"Training type {training_type} not supported.") | |
| training_config = training_cls() | |
| args.extend_args(training_config.add_args, training_config.map_args, training_config.validate_args) | |
| args = args.parse_args() | |
| model_specification_cls = _get_model_specifiction_cls(args.model_name, args.training_type) | |
| model_specification = model_specification_cls( | |
| pretrained_model_name_or_path=args.pretrained_model_name_or_path, | |
| tokenizer_id=args.tokenizer_id, | |
| tokenizer_2_id=args.tokenizer_2_id, | |
| tokenizer_3_id=args.tokenizer_3_id, | |
| text_encoder_id=args.text_encoder_id, | |
| text_encoder_2_id=args.text_encoder_2_id, | |
| text_encoder_3_id=args.text_encoder_3_id, | |
| transformer_id=args.transformer_id, | |
| vae_id=args.vae_id, | |
| text_encoder_dtype=args.text_encoder_dtype, | |
| text_encoder_2_dtype=args.text_encoder_2_dtype, | |
| text_encoder_3_dtype=args.text_encoder_3_dtype, | |
| transformer_dtype=args.transformer_dtype, | |
| vae_dtype=args.vae_dtype, | |
| revision=args.revision, | |
| cache_dir=args.cache_dir, | |
| ) | |
| if args.training_type in [TrainingType.LORA, TrainingType.FULL_FINETUNE]: | |
| trainer = SFTTrainer(args, model_specification) | |
| else: | |
| raise ValueError(f"Training type {args.training_type} not supported.") | |
| trainer.run() | |
| except KeyboardInterrupt: | |
| logger.info("Received keyboard interrupt. Exiting...") | |
| except Exception as e: | |
| logger.error(f"An error occurred during training: {e}") | |
| logger.error(traceback.format_exc()) | |
if __name__ == "__main__":
    # Propagate main()'s status as the process exit code so failed runs are
    # reported as failures; sys.exit(None) still exits with status 0.
    sys.exit(main())