Spaces:
Runtime error
Runtime error
guardrails-genie
/
guardrails_genie
/guardrails
/entity_recognition
/presidio_entity_recognition_guardrail.py
| from typing import List, Dict, Optional, ClassVar, Any | |
| import weave | |
| from pydantic import BaseModel | |
| from presidio_analyzer import AnalyzerEngine, RecognizerRegistry, Pattern, PatternRecognizer | |
| from presidio_anonymizer import AnonymizerEngine | |
| from ..base import Guardrail | |
class PresidioEntityRecognitionResponse(BaseModel):
    """Detailed guardrail result that also lists the matched text per entity type."""

    # True when at least one entity was detected.
    contains_entities: bool
    # Entity type -> matched text spans.
    detected_entities: Dict[str, List[str]]
    # Human-readable summary of what was found and which entity types were checked.
    explanation: str
    # Anonymized text, or None when no anonymization was produced.
    anonymized_text: Optional[str] = None

    def safe(self) -> bool:
        """Return True when no entities were detected, False otherwise."""
        return False if self.contains_entities else True
class PresidioEntityRecognitionSimpleResponse(BaseModel):
    """Compact guardrail result without the per-entity-type matched text."""

    # True when at least one entity was detected.
    contains_entities: bool
    # Human-readable summary of what was found and which entity types were checked.
    explanation: str
    # Anonymized text, or None when no anonymization was produced.
    anonymized_text: Optional[str] = None

    def safe(self) -> bool:
        """Return True when no entities were detected, False otherwise."""
        return False if self.contains_entities else True
# TODO: Add support for the transformers workflow, not just spaCy.
class PresidioEntityRecognitionGuardrail(Guardrail):
    """Guardrail that detects, and optionally anonymizes, entities using Presidio.

    A single analyzer is built with Presidio's default recognizers plus any
    caller-supplied custom, deny-list and regex recognizers. ``selected_entities``
    is then validated against every entity type that analyzer supports, so
    entity types introduced by the extra recognizers are accepted.
    """

    analyzer: AnalyzerEngine
    anonymizer: AnonymizerEngine
    selected_entities: List[str]
    should_anonymize: bool
    language: str

    @staticmethod
    def _entities_from_analyzer(analyzer: AnalyzerEngine) -> List[str]:
        """Return every entity type supported by ``analyzer``'s registered recognizers.

        All supported entities of each recognizer are included (not only the
        first one), de-duplicated in first-seen registry order.
        """
        # dict.fromkeys removes duplicates while keeping insertion order.
        return list(
            dict.fromkeys(
                entity
                for recognizer in analyzer.registry.recognizers
                for entity in recognizer.supported_entities
            )
        )

    @staticmethod
    def get_available_entities() -> List[str]:
        """Return the entity types supported by a default ``AnalyzerEngine``.

        This is a static method, so it can be called on the class itself. It
        uses the same ``AnalyzerEngine()`` construction as ``__init__``, so the
        list matches the built-in types the guardrail validates against (custom
        recognizers passed to ``__init__`` are not included here).
        """
        analyzer = AnalyzerEngine()
        return PresidioEntityRecognitionGuardrail._entities_from_analyzer(analyzer)

    def __init__(
        self,
        selected_entities: Optional[List[str]] = None,
        should_anonymize: bool = False,
        language: str = "en",
        deny_lists: Optional[Dict[str, List[str]]] = None,
        regex_patterns: Optional[Dict[str, List[Dict[str, Any]]]] = None,
        custom_recognizers: Optional[List[Any]] = None,
        show_available_entities: bool = False
    ):
        """Build the Presidio analyzer and anonymizer, then validate the entity selection.

        Args:
            selected_entities: Entity types to detect. Defaults to every entity
                type the configured analyzer supports, including types added by
                ``custom_recognizers``, ``deny_lists`` and ``regex_patterns``.
                Unsupported types are dropped with a printed warning.
            should_anonymize: If True, ``guard`` also returns anonymized text
                when entities are found.
            language: Language code passed to ``analyze``. Deny-list and regex
                recognizers are registered for this language.
            deny_lists: Mapping of entity type -> literal tokens to match.
            regex_patterns: Mapping of entity type -> list of pattern dicts. Each
                dict needs a ``"regex"`` key and may set ``"name"`` (default
                ``pattern_<index>``) and ``"score"`` (default 0.5).
            custom_recognizers: Extra recognizer instances, registered as-is.
            show_available_entities: If True, print the supported entity types.
        """
        # Build ONE analyzer (the NLP model is loaded once) and register every
        # extra recognizer before deciding which entity types are available.
        analyzer = AnalyzerEngine()

        for recognizer in custom_recognizers or []:
            analyzer.registry.add_recognizer(recognizer)

        for entity_type, tokens in (deny_lists or {}).items():
            analyzer.registry.add_recognizer(
                PatternRecognizer(
                    supported_entity=entity_type,
                    # Without this the recognizer uses Presidio's default
                    # language and would not fire for other `language` values.
                    supported_language=language,
                    deny_list=tokens,
                )
            )

        for entity_type, patterns in (regex_patterns or {}).items():
            presidio_patterns = [
                Pattern(
                    name=pattern.get("name", f"pattern_{i}"),
                    regex=pattern["regex"],
                    score=pattern.get("score", 0.5),
                )
                for i, pattern in enumerate(patterns)
            ]
            analyzer.registry.add_recognizer(
                PatternRecognizer(
                    supported_entity=entity_type,
                    supported_language=language,
                    patterns=presidio_patterns,
                )
            )

        available_entities = self._entities_from_analyzer(analyzer)

        if show_available_entities:
            print("\nAvailable entities:")
            print("=" * 25)
            for entity in available_entities:
                print(f"- {entity}")
            print("=" * 25 + "\n")

        # Default to everything the configured analyzer can detect.
        if selected_entities is None:
            selected_entities = list(available_entities)

        available_set = set(available_entities)
        invalid_entities = [e for e in selected_entities if e not in available_set]
        valid_entities = [e for e in selected_entities if e in available_set]
        if invalid_entities:
            print(f"\nWarning: The following entities are not available and will be ignored: {invalid_entities}")
            print(f"Continuing with valid entities: {valid_entities}")
        # Always store a fresh list so later mutation of the caller's list has no effect.
        # NOTE(review): if this ends up empty, Presidio's analyze() may fall back to
        # checking all entity types — confirm against the installed presidio version.
        selected_entities = valid_entities

        super().__init__(
            analyzer=analyzer,
            anonymizer=AnonymizerEngine(),
            selected_entities=selected_entities,
            should_anonymize=should_anonymize,
            language=language
        )

    def guard(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """Check whether ``prompt`` contains any of the selected entities.

        Args:
            prompt: Text to analyze; it is converted with ``str`` once and that
                same text is used for analysis, slicing and anonymization.
            return_detected_types: If True, return a
                ``PresidioEntityRecognitionResponse`` that includes the matched
                text per entity type. Otherwise, return a
                ``PresidioEntityRecognitionSimpleResponse``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A response whose ``contains_entities`` is True when anything was
            detected. ``anonymized_text`` is set only when ``should_anonymize``
            is enabled and at least one entity was found.
        """
        # Use one string consistently; slicing the raw `prompt` with offsets
        # computed on str(prompt) breaks for non-str input.
        text = str(prompt)
        analyzer_results = self.analyzer.analyze(
            text=text,
            entities=self.selected_entities,
            language=self.language
        )

        # Group matched text spans by entity type.
        detected_entities: Dict[str, List[str]] = {}
        for result in analyzer_results:
            detected_entities.setdefault(result.entity_type, []).append(
                text[result.start:result.end]
            )

        explanation_parts = []
        if detected_entities:
            explanation_parts.append("Found the following entities in the text:")
            for entity_type, instances in detected_entities.items():
                explanation_parts.append(f"- {entity_type}: {len(instances)} instance(s)")
        else:
            explanation_parts.append("No entities detected in the text.")

        explanation_parts.append("\nChecked for these entity types:")
        explanation_parts.extend(f"- {entity}" for entity in self.selected_entities)
        explanation = "\n".join(explanation_parts)

        anonymized_text = None
        if self.should_anonymize and detected_entities:
            anonymized_text = self.anonymizer.anonymize(
                text=text,
                analyzer_results=analyzer_results
            ).text

        if return_detected_types:
            return PresidioEntityRecognitionResponse(
                contains_entities=bool(detected_entities),
                detected_entities=detected_entities,
                explanation=explanation,
                anonymized_text=anonymized_text
            )
        return PresidioEntityRecognitionSimpleResponse(
            contains_entities=bool(detected_entities),
            explanation=explanation,
            anonymized_text=anonymized_text
        )

    def predict(self, prompt: str, return_detected_types: bool = True, **kwargs) -> PresidioEntityRecognitionResponse | PresidioEntityRecognitionSimpleResponse:
        """Alias for ``guard`` with the same arguments and return value."""
        return self.guard(prompt, return_detected_types=return_detected_types, **kwargs)