Spaces:
Paused
Paused
| import os | |
| import sys | |
| import torch | |
| import hashlib | |
| from typing import Literal, Optional, Tuple | |
| import transformers | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModel, | |
| AutoTokenizer, | |
| HfArgumentParser, | |
| Seq2SeqTrainingArguments | |
| ) | |
| from transformers.utils import check_min_version | |
| from transformers.utils.versions import require_version | |
| from transformers.modeling_utils import PreTrainedModel | |
| from transformers.tokenization_utils import PreTrainedTokenizer | |
| import datasets | |
| from datasets import Dataset, concatenate_datasets, load_dataset | |
| from peft import ( | |
| PeftModel, | |
| TaskType, | |
| LoraConfig, | |
| get_peft_model | |
| ) | |
| from trl import AutoModelForCausalLMWithValueHead | |
| from .config import ( | |
| ModelArguments, | |
| DataTrainingArguments, | |
| FinetuningArguments | |
| ) | |
| from .other import ( | |
| get_logger, | |
| load_trainable_params, | |
| load_valuehead_params, | |
| print_trainable_params, | |
| prepare_model_for_training, | |
| IGNORE_INDEX, | |
| FINETUNING_ARGS_NAME | |
| ) | |
logger = get_logger(__name__)

# Fail fast at import time when the installed libraries are older than this module needs.
check_min_version("4.27.4")
for _requirement in ("datasets>=2.10.0", "peft>=0.3.0", "trl>=0.4.1"):
    require_version(_requirement, "To fix: pip install {}".format(_requirement))
del _requirement  # keep the loop variable out of the module namespace
def init_adapter(
    model: PreTrainedModel,
    model_args: ModelArguments,
    finetuning_args: FinetuningArguments,
    is_trainable: bool
) -> PreTrainedModel:
    r"""
    Initializes the adapters.

    Note that the trainable parameters must be cast to float32.

    Args:
        model: the base model to prepare.
        model_args: model arguments; ``checkpoint_dir`` (a sequence of checkpoint paths, or None)
            selects previously saved weights to load.
        finetuning_args: fine-tuning arguments; ``finetuning_type`` selects the method
            ("none", "full", "freeze", "p_tuning" or "lora").
        is_trainable: whether the model is being prepared for training.

    Returns:
        The prepared model. For LoRA this may be a ``PeftModel`` wrapper (when training) or the
        base model with the checkpoint(s) merged in (when not training).

    Raises:
        ValueError: if ``finetuning_type`` is "none" while training.
    """
    finetuning_type = finetuning_args.finetuning_type

    if finetuning_type == "none" and is_trainable:
        raise ValueError("You cannot use finetuning_type=none while training.")

    if finetuning_type == "full":
        logger.info("Fine-tuning method: Full")
        model = model.float()
        if model_args.checkpoint_dir is not None:
            load_trainable_params(model, model_args.checkpoint_dir[0])

    elif finetuning_type == "freeze":
        logger.info("Fine-tuning method: Freeze")
        for name, param in model.named_parameters():
            if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
                param.requires_grad_(False)
            else:
                param.data = param.data.to(torch.float32)  # trainable parameters are kept in float32
        if model_args.checkpoint_dir is not None:
            load_trainable_params(model, model_args.checkpoint_dir[0])

    elif finetuning_type == "p_tuning":
        logger.info("Fine-tuning method: P-Tuning v2")  # nothing to do
        if model_args.checkpoint_dir is not None:
            load_trainable_params(model, model_args.checkpoint_dir[0])

    elif finetuning_type == "lora":
        logger.info("Fine-tuning method: LoRA")
        latest_checkpoint = None

        if model_args.checkpoint_dir is not None:
            if is_trainable and finetuning_args.resume_lora_training:
                # Continue training on the last LoRA checkpoint; merge all earlier ones into the base weights.
                checkpoints_to_merge = model_args.checkpoint_dir[:-1]
                latest_checkpoint = model_args.checkpoint_dir[-1]
            else:
                checkpoints_to_merge = model_args.checkpoint_dir

            for checkpoint in checkpoints_to_merge:
                model = PeftModel.from_pretrained(model, checkpoint)
                model = model.merge_and_unload()

            logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))

            if latest_checkpoint is not None:  # resume lora training
                model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True)

        # Only create fresh LoRA weights when training. Wrapping an evaluation model in a new, untrained
        # adapter would make e.g. save_pretrained() save only that adapter instead of the merged model.
        if is_trainable and latest_checkpoint is None:
            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                r=finetuning_args.lora_rank,
                lora_alpha=finetuning_args.lora_alpha,
                lora_dropout=finetuning_args.lora_dropout,
                target_modules=finetuning_args.lora_target
            )
            model = get_peft_model(model, lora_config)

    return model
def load_pretrained(
    model_args: ModelArguments,
    training_args: Optional[Seq2SeqTrainingArguments] = None,
    finetuning_args: Optional[FinetuningArguments] = None,
    is_trainable: Optional[bool] = False,
    stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft"
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    r"""
    Load pretrained model and tokenizer.

    When ``model_args.checkpoint_dir`` is set, the fine-tuning arguments saved in the first checkpoint
    replace ``finetuning_args``. When evaluating without a checkpoint, the original model is loaded
    with ``finetuning_type="none"``.

    Args:
        model_args: model, tokenizer, checkpoint and quantization options.
        training_args: training arguments; only ``fp16`` is read here (cpm quantization while training).
        finetuning_args: fine-tuning method options (overridden by checkpoint values, see above).
        is_trainable: whether the model is prepared for training.
        stage: "sft", "rwd" (reward modeling) or "ppo"; "rwd"/"ppo" add a value head.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if a checkpoint directory lacks the saved fine-tuning arguments, or if
            full-parameter fine-tuning is combined with quantization while training.
    """
    if (not is_trainable) and (model_args.checkpoint_dir is None):
        logger.warning("Checkpoint is not found at evaluation, load the original model.")
        finetuning_args = FinetuningArguments(finetuning_type="none")

    if model_args.checkpoint_dir is not None:  # load fine-tuned model from checkpoint
        for checkpoint_dir in model_args.checkpoint_dir:
            if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
                raise ValueError(
                    "The fine-tuning arguments are not found in the provided directory: {}".format(checkpoint_dir)
                )
        logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
        # NOTE(review): torch.load unpickles this file, so only load checkpoints from trusted sources.
        finetuning_args = torch.load(os.path.join(model_args.checkpoint_dir[0], FINETUNING_ARGS_NAME))
        if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
            logger.warning("Only LoRA tuning accepts multiple checkpoints.")

    assert stage == "sft" or finetuning_args.finetuning_type == "lora", \
        "RM and PPO training can only be performed with LoRA method."

    # Pick the quantization backend: bitsandbytes for Freeze/LoRA training, the model's built-in
    # ("cpm") method for P-Tuning v2 training and for every evaluation.
    quantization = None
    if model_args.quantization_bit is not None:
        if is_trainable:
            if finetuning_args.finetuning_type == "full":
                raise ValueError("Full parameter fine-tuning does not support quantization.")
            elif finetuning_args.finetuning_type == "p_tuning":
                quantization = "cpm"  # use cpm's quantization
            else:
                quantization = "bnb"  # use bnb's quantization
        else:
            quantization = "cpm"

    config_kwargs = {
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
    }

    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        use_fast=model_args.use_fast_tokenizer,
        padding_side="left",
        **config_kwargs
    )

    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        **config_kwargs
    )

    # P-Tuning v2 configurations.
    # We use the built-in p-tuning method of ChatGLM, we cannot use PEFT since the attention masks of ChatGLM are unusual. >_<
    if finetuning_args.finetuning_type == "p_tuning":
        config.pre_seq_len = finetuning_args.pre_seq_len  # enable this will fix other parameters automatically
        config.prefix_projection = finetuning_args.prefix_projection

    # Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library).
    # These kwargs are added after the tokenizer/config were loaded, so they only affect the model load below.
    if quantization == "bnb":
        assert model_args.quantization_bit == 8, "Freeze and LoRA fine-tuning only accept 8-bit quantization."
        require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
        from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
        cuda = get_cuda_lib_handle()
        cc = get_compute_capability(cuda)
        assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."
        config_kwargs["load_in_8bit"] = True
        config_kwargs["device_map"] = "auto"  # it should not be specified outside of load_in_8bit

    # Load and prepare pretrained models (without valuehead).
    model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, **config_kwargs)
    model = prepare_model_for_training(model) if is_trainable else model
    model = init_adapter(model, model_args, finetuning_args, is_trainable)

    if not is_trainable:
        model.requires_grad_(False)  # fix all params
        model = model.half()  # cast all params to float16

    # Quantization with the built-in method for P-Tuning v2 training or evaluation.
    # Model parameters should be cast to float16 in quantized P-Tuning setting.
    if quantization == "cpm":
        assert model_args.quantization_bit in [4, 8], "P-Tuning v2 and inference mode only accept 4-bit or 8-bit quantization."
        # training_args may be None (it is optional); in that case there is no fp16 setting to conflict with.
        assert not (is_trainable and training_args is not None and training_args.fp16), \
            "FP16 training conflicts with cpm quantization."
        model.quantize(model_args.quantization_bit)  # in-place method
        for name, param in model.named_parameters():
            if "prefix_encoder" not in name:
                param.data = param.data.to(torch.float16)  # convert all params in half precision except prefix_encoder

    if quantization is not None:
        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))

    if stage == "rwd" or stage == "ppo":  # add value head
        assert is_trainable, "Reward and PPO stages cannot be performed at evaluation."
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

        if stage == "ppo":  # load reward model
            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
            model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
            load_valuehead_params(model, model_args.reward_model)

        # Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model
        # To meet the compliance requirements of the transformers library
        if quantization == "bnb":
            model._is_int8_training_enabled = True

    print_trainable_params(model)

    return model, tokenizer
def prepare_args() -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
    r"""
    Parses the program arguments, configures logging, validates the requested operation and sets the seed.

    Arguments are read from a single ``.json`` file when that is the only command-line argument,
    otherwise from the command line.

    Returns:
        A tuple ``(model_args, data_args, training_args, finetuning_args)``.

    Raises:
        ValueError: if not exactly one of ``do_train``, ``do_eval`` and ``do_predict`` is set.
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):  # Provide arguments with a json file.
        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # Setup logging
    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
    if int(training_args.do_train) + int(training_args.do_eval) + int(training_args.do_predict) != 1:
        raise ValueError("We must perform a single operation among do_train, do_eval and do_predict.")

    if model_args.quantization_bit is not None and not training_args.do_train:
        logger.warning("We do not recommend evaluating the model in 4/8-bit mode.")

    if training_args.do_train and (not training_args.fp16):
        logger.warning("We recommend enabling fp16 mixed precision training for ChatGLM-6B.")

    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim  # suppress warning

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
        + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    transformers.set_seed(training_args.seed)

    return model_args, data_args, training_args, finetuning_args
def prepare_data(
    model_args: ModelArguments,
    data_args: DataTrainingArguments
) -> Dataset:
    r"""
    Loads every dataset in ``data_args.dataset_list`` and normalizes it to a common schema.

    Each entry is loaded from the Hugging Face Hub ("hf_hub"), a loading script under
    ``data_args.dataset_dir`` ("script"), or a local json/jsonl/csv file ("file"). Its
    ``data_args.split`` split is then taken and optionally truncated to ``data_args.max_samples`` rows.
    Its columns are renamed (or created and filled with ``None``) so that every dataset has
    ``prompt``, ``query``, ``response`` and ``history`` columns. Multiple datasets are concatenated.

    Args:
        model_args: read for ``cache_dir`` and ``use_auth_token``.
        data_args: dataset list, directory, split and sample limit.

    Returns:
        A single ``Dataset``.

    Raises:
        NotImplementedError: if an entry's ``load_from`` is not "hf_hub", "script" or "file".
    """
    # `datasets` reads JSON Lines with its "json" builder; there is no builder named "jsonl".
    ext_to_builder = {"jsonl": "json"}

    def checksum(file_path, expected_sha1):
        # Hash the file in chunks so large data files are not read into memory at once.
        sha1 = hashlib.sha1()
        with open(file_path, "rb") as datafile:
            for chunk in iter(lambda: datafile.read(1 << 20), b""):
                sha1.update(chunk)
        if sha1.hexdigest() != expected_sha1:
            logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))

    max_samples = data_args.max_samples
    all_datasets = []  # support multiple datasets

    for dataset_info in data_args.dataset_list:
        logger.info("Loading dataset {}...".format(dataset_info))

        if dataset_info.load_from == "hf_hub":
            raw_datasets = load_dataset(dataset_info.dataset_name, cache_dir=model_args.cache_dir)
        elif dataset_info.load_from == "script":
            raw_datasets = load_dataset(
                os.path.join(data_args.dataset_dir, dataset_info.dataset_name),
                cache_dir=model_args.cache_dir
            )
        elif dataset_info.load_from == "file":
            data_file = os.path.join(data_args.dataset_dir, dataset_info.file_name)  # support json, jsonl and csv
            extension = dataset_info.file_name.split(".")[-1]

            if dataset_info.file_sha1 is not None:
                checksum(data_file, dataset_info.file_sha1)
            else:
                logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.")

            raw_datasets = load_dataset(
                ext_to_builder.get(extension, extension),
                data_files=data_file,
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None
            )
        else:
            raise NotImplementedError

        dataset = raw_datasets[data_args.split]

        if max_samples is not None:
            max_samples_temp = min(len(dataset), max_samples)
            dataset = dataset.select(range(max_samples_temp))

        dummy_data = [None] * len(dataset)
        for column, column_name in [
            ("prompt_column", "prompt"),
            ("query_column", "query"),
            ("response_column", "response"),
            ("history_column", "history")
        ]:  # every dataset will have 4 columns same as each other
            source_column = getattr(dataset_info, column)
            if source_column != column_name:
                if source_column:
                    dataset = dataset.rename_column(source_column, column_name)
                else:  # None or empty string
                    dataset = dataset.add_column(column_name, dummy_data)

        all_datasets.append(dataset)

    if len(all_datasets) == 1:
        return all_datasets[0]
    return concatenate_datasets(all_datasets)
def preprocess_data(
    dataset: Dataset,
    tokenizer: PreTrainedTokenizer,
    data_args: DataTrainingArguments,
    training_args: Seq2SeqTrainingArguments,
    stage: Optional[Literal["sft", "rwd", "ppo"]] = "sft"
) -> Dataset:
    r"""
    Tokenizes ``dataset`` for the given training stage and prints its first processed example.

    The ``prompt``/``query``/``response``/``history`` columns are replaced by:
      - "sft" training: ``input_ids`` and ``labels`` (prompt positions masked with IGNORE_INDEX);
      - "sft" evaluation/prediction: left-padded ``input_ids``/``attention_mask`` plus ``labels``;
      - "rwd": ``accept_ids`` and ``reject_ids`` (``response`` is indexed as [accepted, rejected]);
      - "ppo": prompt-only ``input_ids``.

    Raises:
        ValueError: if ``stage`` is not one of "sft", "rwd" or "ppo".
    """
    column_names = list(dataset.column_names)
    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    def format_example(examples):  # support question with a single answer or multiple answers
        # Yields (prompt, answer) for every row that has both a prompt and a response; rows with a
        # history are rendered as a multi-round conversation ending with the current query.
        for i in range(len(examples["prompt"])):
            if examples["prompt"][i] and examples["response"][i]:
                query, answer = examples["prompt"][i], examples["response"][i]
                if examples["query"][i]:
                    query += examples["query"][i]
                if examples["history"][i]:
                    prompt = ""
                    history = examples["history"][i]
                    for j, (old_query, response) in enumerate(history):
                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(j, old_query, response)
                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
                else:
                    prompt = query
                prompt = prefix + prompt
                yield prompt, answer

    def preprocess_function_train(examples):
        # build inputs with format `X [gMASK] [BOS] Y [EOS]` and labels with format `[IGNORE] ... [IGNORE] [BOS] Y [EOS]`
        model_inputs = {"input_ids": [], "labels": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            target_ids = tokenizer.encode(text=answer, add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2:  # gmask and bos tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            if len(target_ids) > data_args.max_target_length - 1:  # eos token
                target_ids = target_ids[:data_args.max_target_length - 1]

            input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids)

            context_length = input_ids.index(tokenizer.bos_token_id)
            labels = [IGNORE_INDEX] * context_length + input_ids[context_length:]

            model_inputs["input_ids"].append(input_ids)
            model_inputs["labels"].append(labels)
        return model_inputs

    def preprocess_function_eval(examples):
        # build inputs with format `[PAD] ... [PAD] X [gMASK] [BOS]` and labels with format `Y [gMASK] [BOS]`
        # left-padding is needed for prediction, use the built-in function of the tokenizer
        inputs, targets = [], []
        for prompt, answer in format_example(examples):
            inputs.append(prompt)
            targets.append(answer)
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
        labels = tokenizer(text_target=targets, max_length=data_args.max_target_length, truncation=True)  # no padding
        if data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l_id if l_id != tokenizer.pad_token_id else IGNORE_INDEX) for l_id in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def preprocess_function_train_pair(examples):
        # build input pairs with format `X [gMASK] [BOS] Y1 [EOS]` and `X [gMASK] [BOS] Y2 [EOS]`
        model_inputs = {"accept_ids": [], "reject_ids": []}
        for prompt, answer in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
            reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)

            if len(source_ids) > data_args.max_source_length - 2:  # gmask and bos tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            if len(accept_ids) > data_args.max_target_length - 1:  # eos token
                accept_ids = accept_ids[:data_args.max_target_length - 1]
            if len(reject_ids) > data_args.max_target_length - 1:  # eos token
                reject_ids = reject_ids[:data_args.max_target_length - 1]

            accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids)  # avoid copying error
            reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids)

            model_inputs["accept_ids"].append(accept_ids)
            model_inputs["reject_ids"].append(reject_ids)
        return model_inputs

    def preprocess_function_train_ppo(examples):
        # build inputs with format `X [gMASK] [BOS]`
        model_inputs = {"input_ids": []}
        for prompt, _ in format_example(examples):
            source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
            if len(source_ids) > data_args.max_source_length - 2:  # gmask and bos tokens
                source_ids = source_ids[:data_args.max_source_length - 2]
            input_ids = tokenizer.build_inputs_with_special_tokens(source_ids)
            model_inputs["input_ids"].append(input_ids)
        return model_inputs

    def print_sft_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
        print("label_ids:\n{}".format(example["labels"]))
        # IGNORE_INDEX marks masked label positions rather than a vocabulary id, so swap it for the
        # pad token before decoding instead of handing it to the tokenizer.
        decodable_labels = [
            label_id if label_id != IGNORE_INDEX else tokenizer.pad_token_id for label_id in example["labels"]
        ]
        print("labels:\n{}".format(tokenizer.decode(decodable_labels)))

    def print_pairwise_dataset_example(example):
        print("accept_ids:\n{}".format(example["accept_ids"]))
        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"])))
        print("reject_ids:\n{}".format(example["reject_ids"]))
        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"])))

    def print_ppo_dataset_example(example):
        print("input_ids:\n{}".format(example["input_ids"]))
        print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))

    if stage == "sft":
        preprocess_function = preprocess_function_train if training_args.do_train else preprocess_function_eval
        print_function = print_sft_dataset_example
    elif stage == "rwd":
        preprocess_function = preprocess_function_train_pair
        print_function = print_pairwise_dataset_example
    elif stage == "ppo":
        preprocess_function = preprocess_function_train_ppo
        print_function = print_ppo_dataset_example
    else:
        # Previously an unknown stage surfaced as an UnboundLocalError inside dataset.map().
        raise ValueError("Unknown stage: {}".format(stage))

    with training_args.main_process_first(desc="dataset map pre-processing"):
        dataset = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset"
        )

        print_function(dataset[0])

        return dataset