# Punctuation restoration: loads Oliver Guhr's model and restores punctuation in raw text.

import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline

# Model
MODEL_NAME = "oliverguhr/fullstop-punctuation-multilang-large"
DEVICE = 0 if torch.cuda.is_available() else -1

print(f"Loading punctuation model ({MODEL_NAME}) on {'GPU' if DEVICE == 0 else 'CPU'}...")

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)

# Token-classification pipeline; "simple" aggregation groups sub-word pieces into whole words.
punctuation_pipeline = pipeline(
    "token-classification",
    model=model,
    tokenizer=tokenizer,
    device=DEVICE,
    aggregation_strategy="simple",
)

# Map predicted entity groups to punctuation marks. The fullstop checkpoints generally emit the
# punctuation characters themselves as labels ("." "," "?", with "0" for no punctuation); the
# symbolic spellings are kept as a fallback in case a differently configured checkpoint is used.
LABEL_TO_PUNCT = {
    ".": ".", "PERIOD": ".",
    ",": ",", "COMMA": ",",
    "?": "?", "QUESTION": "?",
}


# Main function
def punctuate_text(text: str) -> str:
    """
    Restores punctuation in the given text using Oliver Guhr's model.
    Returns the punctuated text, or the original text unchanged on error.
    """
    if not text.strip():
        return text

    try:
        results = punctuation_pipeline(text)

        words = []
        for item in results:
            # Strip the SentencePiece word-boundary marker and surrounding whitespace, if present.
            word = item["word"].replace("▁", " ").strip()
            if not word:
                continue
            # Append the predicted punctuation mark; unmapped labels leave the word as-is.
            punct = LABEL_TO_PUNCT.get(item["entity_group"], "")
            words.append(word + punct)

        # Clean spacing
        return " ".join(words)
    except Exception as e:
        print(f"[punctuate_text] Error: {e}")
        return text
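

# A minimal usage sketch: running the file directly punctuates a short sample string.
# The sample sentence below is illustrative only, not taken from the model card.
if __name__ == "__main__":
    sample = "hello how are you today i hope everything is fine"
    print(punctuate_text(sample))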