Spaces:
Runtime error
Runtime error
| """ | |
| TrainModelCommand class | |
| ============================== | |
| """ | |
| from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser | |
| from textattack import CommandLineTrainingArgs, Trainer | |
| from textattack.commands import TextAttackCommand | |
class TrainModelCommand(TextAttackCommand):
    """The TextAttack train module:

    A command line parser to train a model from user specifications.
    """

    def run(self, args):
        """Train a model as described by the parsed command-line arguments.

        The attributes of ``args`` are expanded into keyword arguments for
        ``CommandLineTrainingArgs``; the model wrapper, the train/eval
        datasets and the attack are then created from those training
        arguments and handed to a ``Trainer``, whose ``train()`` is invoked.

        Args:
            args: Parsed argument namespace (anything ``vars()`` accepts)
                carrying the options registered by ``register_subcommand``.
        """
        training_args = CommandLineTrainingArgs(**vars(args))
        model_wrapper = CommandLineTrainingArgs._create_model_from_args(training_args)
        train_dataset, eval_dataset = CommandLineTrainingArgs._create_dataset_from_args(
            training_args
        )
        # The attack is built against the freshly created model wrapper.
        attack = CommandLineTrainingArgs._create_attack_from_args(
            training_args, model_wrapper
        )
        trainer = Trainer(
            model_wrapper,
            training_args.task_type,
            attack,
            train_dataset,
            eval_dataset,
            training_args,
        )
        trainer.train()

    # Declared static: the function takes no ``self``. Without the decorator,
    # calling it on an instance would bind the instance to ``main_parser``.
    @staticmethod
    def register_subcommand(main_parser: ArgumentParser):
        """Register the ``train`` subcommand and its arguments.

        Args:
            main_parser: Object exposing ``add_parser(name, **kwargs)``.
                NOTE(review): despite the ``ArgumentParser`` annotation this
                looks like the sub-parsers action returned by
                ``ArgumentParser.add_subparsers()`` -- confirm against caller.
        """
        parser = main_parser.add_parser(
            "train",
            help="train a model for sequence classification",
            formatter_class=ArgumentDefaultsHelpFormatter,
        )
        # Keep whatever parser object the args helper hands back.
        parser = CommandLineTrainingArgs._add_parser_args(parser)
        # Store a command instance under ``func`` in the parsed namespace.
        parser.set_defaults(func=TrainModelCommand())