Spaces:
Runtime error
Runtime error
| import string | |
| from transformers import ( | |
| AutoModelForSeq2SeqLM, | |
| AutoTokenizer, | |
| Text2TextGenerationPipeline, | |
| ) | |
class KeyphraseGenerationPipeline(Text2TextGenerationPipeline):
    """Text2text pipeline that returns generated keyphrases as a list.

    The generated text is split on ``keyphrase_sep_token``; each piece has
    the characters in ``string.punctuation`` removed and surrounding
    whitespace stripped, and pieces that end up empty are discarded.
    """

    def __init__(self, model, keyphrase_sep_token=";", *args, **kwargs):
        """Load the seq2seq model and tokenizer named by *model*.

        Args:
            model: Identifier or path handed to both
                ``AutoModelForSeq2SeqLM.from_pretrained`` and
                ``AutoTokenizer.from_pretrained``.
            keyphrase_sep_token: Separator expected between keyphrases in the
                generated text.
            *args, **kwargs: Forwarded to ``Text2TextGenerationPipeline``.
        """
        # Model and tokenizer are passed positionally so that any extra
        # positional *args bind to the parameters after them; combined with
        # ``model=`` as a keyword, *args would be bound to ``model`` first and
        # raise "got multiple values for argument 'model'".
        super().__init__(
            AutoModelForSeq2SeqLM.from_pretrained(model),
            AutoTokenizer.from_pretrained(model, truncation=True),
            *args,
            **kwargs,
        )
        self.keyphrase_sep_token = keyphrase_sep_token

    def postprocess(self, model_outputs, **postprocess_params):
        """Turn the model's generated text into a cleaned list of keyphrases.

        Args:
            model_outputs: Outputs passed in by the pipeline after the forward
                step; handed unchanged to the parent ``postprocess``.
            **postprocess_params: Extra options (e.g. ``return_type``,
                ``clean_up_tokenization_spaces``) forwarded to the parent
                ``postprocess``.

        Returns:
            list[str]: Keyphrases taken from the first generated result only,
            with punctuation removed, surrounding whitespace stripped, and
            empty entries dropped.
        """
        results = super().postprocess(
            model_outputs=model_outputs, **postprocess_params
        )
        # Only the first generated result is used; any others are ignored.
        generated_text = results[0].get("generated_text")
        # Built once instead of per keyphrase.
        remove_punctuation = str.maketrans("", "", string.punctuation)

        keyphrases = []
        for keyphrase in generated_text.split(self.keyphrase_sep_token):
            # Remove punctuation *before* stripping so whitespace exposed by
            # removed punctuation is trimmed too, and filter on the cleaned
            # value so whitespace-only pieces don't produce "" entries.
            cleaned = keyphrase.translate(remove_punctuation).strip()
            if cleaned:
                keyphrases.append(cleaned)
        return keyphrases