Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import transformers | |
| from transformers import Trainer | |
| from xtuner.apis import DefaultTrainingArguments, build_qlora_model | |
| from xtuner.apis.datasets import alpaca_data_collator, alpaca_dataset | |
def train():
    """Run one fine-tuning job driven entirely by command-line arguments.

    Steps, in order:
      1. Parse the command line into a ``DefaultTrainingArguments`` instance.
      2. Build the model and tokenizer with ``build_qlora_model``.
      3. Build the training dataset with ``alpaca_dataset`` and the batch
         collator with ``alpaca_data_collator``.
      4. Train with a Hugging Face ``Trainer``, then save the trainer state
         and the model to ``output_dir``.

    Takes no parameters and returns ``None``; all configuration comes from
    the parsed arguments.
    """
    # Command line -> DefaultTrainingArguments (first and only dataclass
    # handed to the parser).
    arg_parser = transformers.HfArgumentParser(DefaultTrainingArguments)
    parsed_dataclasses = arg_parser.parse_args_into_dataclasses()
    training_args = parsed_dataclasses[0]

    # Model and tokenizer are created together so the dataset can be
    # tokenized with the matching tokenizer.
    qlora_model, tok = build_qlora_model(
        model_name_or_path=training_args.model_name_or_path,
        return_tokenizer=True,
    )
    dataset = alpaca_dataset(
        tokenizer=tok,
        path=training_args.dataset_name_or_path,
    )
    collate_fn = alpaca_data_collator(return_hf_format=True)

    # Assemble the Trainer from the pieces built above.
    trainer_kwargs = {
        'model': qlora_model,
        'args': training_args,
        'train_dataset': dataset,
        'data_collator': collate_fn,
    }
    hf_trainer = Trainer(**trainer_kwargs)

    # Train, then persist the trainer state followed by the model.
    hf_trainer.train()
    hf_trainer.save_state()
    hf_trainer.save_model(output_dir=training_args.output_dir)
# Script entry point: start training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    train()