Spaces:
Running
Running
| from peft import PeftConfig, PeftModel | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline, RobertaTokenizerFast | |
| import nltk | |
| import re | |
class CommaFixer:
    """
    A wrapper class for the fine-tuned comma fixer model.

    On construction the base token classification model and its peft adapter are loaded, merged and put into
    evaluation mode, together with the tokenizer of the base model.
    """

    def __init__(self, device=-1):
        """
        :param device: Kept for interface compatibility. It is currently unused: the model is not moved to any
        device by this class.
        """
        self.id2label = {0: 'O', 1: 'B-COMMA'}
        self.label2id = {'O': 0, 'B-COMMA': 1}
        self.model, self.tokenizer = self._load_peft_model()

    def fix_commas(self, s: str) -> str:
        """
        The main method for fixing commas using the fine-tuned model.
        In the future we should think about batching the calls to it, for now it processes requests string by string.

        All commas (and the whitespace directly before them) are removed from the input first, then the model
        decides after which tokens a comma should be put back.

        :param s: A string with commas to fix, without length restrictions.
        However, if the string is longer than the length limit of the tokenizer, it is split into sentences which
        are fixed separately and joined with single spaces, so some whitespaces might be trimmed.
        A single sentence that is still longer than the limit cannot be split any further and is returned
        unchanged.
        Example: comma_fixer.fix_commas("One two thre, and four!")
        :return: A string with commas fixed, example: "One, two, thre and four!"
        """
        # Function-scope import: torch is only needed here, to switch off gradient tracking during inference.
        import torch

        s_no_commas = re.sub(r'\s*,', '', s)
        tokenized = self.tokenizer(s_no_commas, return_tensors='pt', return_offsets_mapping=True, return_length=True)

        # If text too long, split into sentences and fix commas separately.
        # TODO this is slow, we should think about joining them until length, or maybe a length limit to avoid
        # stalling the whole service
        if tokenized['length'][0] > self.tokenizer.model_max_length:
            sentences = nltk.sent_tokenize(s)
            if len(sentences) <= 1:
                # The sentence splitter could not make the text any smaller. Recursing on it again would never
                # terminate, so give the text back untouched instead.
                return s
            return ' '.join(self.fix_commas(sentence) for sentence in sentences)

        # Pure inference: no gradients are needed, so do not build the autograd graph.
        with torch.no_grad():
            logits = self.model(input_ids=tokenized['input_ids'], attention_mask=tokenized['attention_mask']).logits

        # One label per token of the single sequence in the batch (special tokens included).
        labels = [self.id2label[tag_id.item()] for tag_id in logits.argmax(dim=2).flatten()]
        return _fix_commas_based_on_labels_and_offsets(labels, s_no_commas, tokenized['offset_mapping'][0])

    def _load_peft_model(self, model_name="klasocki/roberta-large-lora-ner-comma-fixer") -> tuple[
            PeftModel, RobertaTokenizerFast]:
        """
        Creates the huggingface model and tokenizer.
        Can also be used for pre-downloading the model and the tokenizer.
        :param model_name: Name of the model on the huggingface hub.
        :return: A model with the peft adapter injected and weights merged, and the tokenizer.
        The model is returned in evaluation mode.
        """
        config = PeftConfig.from_pretrained(model_name)
        # The base model is resolved from the adapter config and set up with the two comma labels of this class.
        inference_model = AutoModelForTokenClassification.from_pretrained(
            config.base_model_name_or_path, num_labels=len(self.id2label), id2label=self.id2label,
            label2id=self.label2id
        )
        tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
        model = PeftModel.from_pretrained(inference_model, model_name)
        model = model.merge_and_unload()  # Join LoRa matrices with the main model for faster inference
        # TODO batch, and move to CUDA if available
        return model.eval(), tokenizer
def _fix_commas_based_on_labels_and_offsets(
        labels: list[str],
        original_s: str,
        offset_map: list[tuple[int, int]]
) -> str:
    """
    This function returns the original string with only commas fixed, based on the predicted labels from the main
    model and the offsets from the tokenizer.
    :param labels: Predicted labels for the tokens.
    Should already be converted to string, since we will look for B-COMMA tags.
    :param original_s: The original string, used to preserve original spacing and punctuation.
    :param offset_map: Pairs of (start, end) offsets in the original string, one pair per label.
    Commas are inserted at the end offset, that is where the token ended originally in the string.
    Tokens with an empty span (start == end), such as the (0, 0) offsets of special tokens, are ignored.
    :return: The string with commas fixed, and everything else intact.
    """
    result = original_s
    commas_inserted = 0
    for label, (start, end) in zip(labels, offset_map):
        # int() normalizes the offsets, so that plain ints and single-element tensors are handled alike.
        start, end = int(start), int(end)
        if start == end:
            # A zero-width span does not correspond to any text. Shifting it by the number of commas inserted
            # so far would point at an unrelated position, so such tokens must never receive a comma.
            continue
        # Every comma inserted so far shifted the rest of the string one character to the right.
        current_offset = end + commas_inserted
        if _should_insert_comma(label, result, current_offset):
            result = result[:current_offset] + ',' + result[current_offset:]
            commas_inserted += 1
    return result
def _should_insert_comma(label, result, current_offset) -> bool:
    """
    Decides whether a comma should be inserted into the string at the given position.
    :param label: Predicted label of the token, a comma is only considered for 'B-COMMA'.
    :param result: The string built so far, into which the comma would be inserted.
    :param current_offset: Position in result directly after the token.
    :return: True if the label asks for a comma and the character at current_offset is a whitespace.
    False otherwise, including when current_offset lies outside of the string.
    """
    # Only insert commas for the final token of a word, that is, if next word starts with a space.
    # TODO perhaps for low confidence tokens, we should use the original decision of the user in the input?
    if label != 'B-COMMA':
        return False
    offset = int(current_offset)
    # A token ending at the very end of the string has no following character to inspect; without this check
    # indexing would raise IndexError. No comma is put at the end of the text.
    if not 0 <= offset < len(result):
        return False
    return result[offset].isspace()
# Script entry point: building a CommaFixer loads the model and the tokenizer, so running this module directly
# can be used to fetch them ahead of time. The created instance is discarded.
if __name__ == "__main__":
    CommaFixer()  # to pre-download the model and tokenizer