from transformers import ConditionalDetrImageProcessor, TrOCRProcessor, ViTImageProcessor import torch from typing import List, Dict, Any, Optional, Tuple from shapely.geometry import box from shapely.geometry.polygon import Polygon from .utils import x1y1x2y2_to_xywh import numpy as np from numpy.typing import NDArray class Magiv2Processor(): """ Procesor danych dla modelu Magiv2 - obsługuje preprocessing i postprocessing. Klasa odpowiedzialna za przygotowanie danych wejściowych dla różnych modułów Magiv2 (detekcja, OCR, embeddingi) oraz przetwarzanie outputów. Zawiera również metody pomocnicze do filtrowania detekcji i konwersji formatów anotacji. Attributes: config: Konfiguracja modelu Magiv2 detection_image_preprocessor: Preprocessor dla obrazów do detekcji obiektów ocr_preprocessor: Preprocessor dla obrazów do OCR crop_embedding_image_preprocessor: Preprocessor dla wyciętych fragmentów obrazu """ def __init__(self, config: Any) -> None: """ Inicjalizuje procesor z podaną konfiguracją. Tworzy preprocessory dla modułów, które są aktywne zgodnie z konfiguracją: - Detekcja obiektów: ConditionalDetrImageProcessor - OCR: TrOCRProcessor - Embeddingi crops: ViTImageProcessor Args: config: Obiekt konfiguracji Magiv2Config z parametrami preprocessingu """ self.config: Any = config self.detection_image_preprocessor: Optional[ConditionalDetrImageProcessor] = None self.ocr_preprocessor: Optional[TrOCRProcessor] = None self.crop_embedding_image_preprocessor: Optional[ViTImageProcessor] = None # Inicjalizacja preprocessora dla detekcji obiektów (jeśli aktywny) if not config.disable_detections: assert config.detection_image_preprocessing_config is not None self.detection_image_preprocessor = ConditionalDetrImageProcessor.from_dict( config.detection_image_preprocessing_config) # Inicjalizacja preprocessora dla OCR (jeśli aktywny) if not config.disable_ocr: assert config.ocr_pretrained_processor_path is not None self.ocr_preprocessor = TrOCRProcessor.from_pretrained( config.ocr_pretrained_processor_path) # Inicjalizacja preprocessora dla embeddingów crops (jeśli aktywny) if not config.disable_crop_embeddings: assert config.crop_embedding_image_preprocessing_config is not None self.crop_embedding_image_preprocessor = ViTImageProcessor.from_dict( config.crop_embedding_image_preprocessing_config) def preprocess_inputs_for_detection( self, images: List[NDArray[np.uint8]], annotations: Optional[List[Dict[str, Any]]] = None ) -> Dict[str, torch.Tensor]: """ Preprocessuje obrazy do formatu wymaganego przez moduł detekcji obiektów. Wykonuje normalizację, resize i padding obrazów. Jeśli podano anotacje, konwertuje je do formatu COCO i skaluje współrzędnie bbox zgodnie z resize. Args: images: Lista obrazów jako numpy arrays (format HWC) annotations: Opcjonalne anotacje ground truth w formacie: [{"image_id": int, "bboxes_as_x1y1x2y2": List, "labels": List}] Returns: Słownik z kluczami: - "pixel_values": torch.Tensor z preprocessowanymi obrazami - "pixel_mask": torch.Tensor z maską paddingu - "labels": List[Dict] z przetworzonymi anotacjami (jeśli podano) """ images_list: List[NDArray[np.uint8]] = list(images) assert isinstance(images_list[0], np.ndarray) # Konwersja anotacji do formatu COCO (bbox w formacie xywh zamiast x1y1x2y2) coco_annotations: Optional[List[Dict[str, Any]] ] = self._convert_annotations_to_coco_format(annotations) # Preprocessing obrazów i anotacji inputs: Dict[str, torch.Tensor] = self.detection_image_preprocessor( images_list, annotations=coco_annotations, return_tensors="pt") return inputs def preprocess_inputs_for_ocr(self, images: List[NDArray[np.uint8]]) -> torch.Tensor: """ Preprocessuje obrazy do formatu wymaganego przez moduł OCR. Wykonuje normalizację i resize obrazów tekstowych dla modelu TrOCR. Args: images: Lista obrazów jako numpy arrays (fragmenty z tekstem) Returns: Tensor z preprocessowanymi obrazami [batch, channels, height, width] """ images_list: List[NDArray[np.uint8]] = list(images) assert isinstance(images_list[0], np.ndarray) return self.ocr_preprocessor(images_list, return_tensors="pt").pixel_values def preprocess_inputs_for_crop_embeddings(self, images: List[NDArray[np.uint8]]) -> torch.Tensor: """ Preprocessuje wycięte fragmenty obrazów dla modułu embeddingów. Wykonuje normalizację i resize crops dla modelu ViT-MAE. Args: images: Lista wyciętych fragmentów obrazów jako numpy arrays Returns: Tensor z preprocessowanymi crops [batch, channels, height, width] """ images_list: List[NDArray[np.uint8]] = list(images) assert isinstance(images_list[0], np.ndarray) return self.crop_embedding_image_preprocessor(images_list, return_tensors="pt").pixel_values def postprocess_ocr_tokens( self, generated_ids: torch.Tensor, skip_special_tokens: bool = True ) -> List[str]: """ Dekoduje tokeny wygenerowane przez model OCR na tekst. Args: generated_ids: Tensor z ID tokenów wygenerowanych przez decoder OCR skip_special_tokens: Czy pomijać specjalne tokeny (PAD, BOS, EOS) w wyniku Returns: Lista stringów z rozpoznanym tekstem """ return self.ocr_preprocessor.batch_decode(generated_ids, skip_special_tokens=skip_special_tokens) def crop_image( self, image: NDArray[np.uint8], bboxes: List[List[float]] ) -> List[NDArray[np.uint8]]: """ Wycina fragmenty obrazu zgodnie z podanymi bounding boxami. Metoda automatycznie naprawia nieprawidłowe bounding boxy: - Ogranicza współrzędne do granic obrazu - Zapewnia minimalny rozmiar 10x10 pikseli - Zamienia współrzędne jeśli są w nieprawidłowej kolejności Args: image: Obraz źródłowy jako numpy array (format HWC) bboxes: Lista bounding boxów w formacie [x1, y1, x2, y2] Returns: Lista wyciętych fragmentów obrazu (każdy jako numpy array) """ crops_for_image: List[NDArray[np.uint8]] = [] for bbox in bboxes: x1: float y1: float x2: float y2: float x1, y1, x2, y2 = bbox # Naprawa bounding boxa w przypadku gdy jest poza granicami lub za mały # Konwersja do int x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) # Upewnienie się że x1 10: x2 = x1 + 10 else: x1 = x2 - 10 # Zapewnienie minimalnej wysokości 10 pikseli if y2 - y1 < 10: if image.shape[0] - y1 > 10: y2 = y1 + 10 else: y1 = y2 - 10 # Wycięcie fragmentu obrazu crop: NDArray[np.uint8] = image[y1:y2, x1:x2] crops_for_image.append(crop) return crops_for_image def _get_indices_of_characters_to_keep( self, batch_scores: torch.Tensor, batch_labels: torch.Tensor, batch_bboxes: torch.Tensor, character_detection_threshold: float ) -> List[torch.Tensor]: """ Filtruje detekcje postaci na podstawie progu prawdopodobieństwa. Zachowuje tylko detekcje z etykietą 0 (postać) i score powyżej progu. Args: batch_scores: Tensor ze scorami prawdopodobieństwa [batch, num_queries] batch_labels: Tensor z etykietami klas [batch, num_queries] batch_bboxes: Tensor z bounding boxami [batch, num_queries, 4] character_detection_threshold: Minimalny score do zachowania detekcji (0-1) Returns: Lista tensorów z indeksami postaci do zachowania dla każdego obrazu """ indices_of_characters_to_keep: List[torch.Tensor] = [] for scores, labels, _ in zip(batch_scores, batch_labels, batch_bboxes): # Filtrowanie: label=0 (postać) AND score > próg indices: torch.Tensor = torch.where((labels == 0) & ( scores > character_detection_threshold))[0] indices_of_characters_to_keep.append(indices) return indices_of_characters_to_keep def _get_indices_of_panels_to_keep( self, batch_scores: torch.Tensor, batch_labels: torch.Tensor, batch_bboxes: torch.Tensor, panel_detection_threshold: float ) -> List[List[int]]: """ Filtruje detekcje paneli z zastosowaniem NMS (Non-Maximum Suppression). Zachowuje tylko panele z etykietą 2 i score powyżej progu. Dodatkowo stosuje NMS aby usunąć nakładające się panele - jeśli nowy panel pokrywa się w >50% z już zaakceptowanymi panelami, jest odrzucany. Args: batch_scores: Tensor ze scorami [batch, num_queries] batch_labels: Tensor z etykietami [batch, num_queries] batch_bboxes: Tensor z bboxami [batch, num_queries, 4] panel_detection_threshold: Minimalny score do zachowania panelu Returns: Lista list indeksów paneli do zachowania (po NMS) dla każdego obrazu """ indices_of_panels_to_keep: List[List[int]] = [] for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes): # Wybranie tylko detekcji z label=2 (panel) indices: torch.Tensor = torch.where(labels == 2)[0] bboxes = bboxes[indices] scores = scores[indices] labels = labels[indices] if len(indices) == 0: indices_of_panels_to_keep.append([]) continue # Sortowanie paneli malejąco po score (najlepsze pierwsze) scores, labels, indices, bboxes = zip( *sorted(zip(scores, labels, indices, bboxes), reverse=True)) panels_to_keep: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] # Unia wszystkich zaakceptowanych paneli (do sprawdzania nakładania) union_of_panels_so_far: Polygon = box(0, 0, 0, 0) for ps, pb, pl, pi in zip(scores, bboxes, labels, indices): # Konwersja bbox na polygon Shapely panel_polygon: Polygon = box(pb[0], pb[1], pb[2], pb[3]) # Odrzuć jeśli score poniżej progu if ps < panel_detection_threshold: continue # Odrzuć jeśli panel nakłada się >50% z już zaakceptowanymi panelami (NMS) if union_of_panels_so_far.intersection(panel_polygon).area / panel_polygon.area > 0.5: continue # Zaakceptuj panel panels_to_keep.append((ps, pl, pb, pi)) # Dodaj do unii zaakceptowanych paneli union_of_panels_so_far = union_of_panels_so_far.union( panel_polygon) # Wyciągnięcie indeksów zaakceptowanych paneli indices_of_panels_to_keep.append( [p[3].item() for p in panels_to_keep]) return indices_of_panels_to_keep def _get_indices_of_texts_to_keep( self, batch_scores: torch.Tensor, batch_labels: torch.Tensor, batch_bboxes: torch.Tensor, text_detection_threshold: float ) -> List[List[int]]: """ Filtruje detekcje tekstu z zastosowaniem NMS (Non-Maximum Suppression). Zachowuje tylko tekst z etykietą 1 i score powyżej progu. Stosuje NMS aby usunąć duplikaty - jeśli nowy tekst ma IoU >0.5 z już zaakceptowanym tekstem, jest odrzucany. Args: batch_scores: Tensor ze scorami [batch, num_queries] batch_labels: Tensor z etykietami [batch, num_queries] batch_bboxes: Tensor z bboxami [batch, num_queries, 4] text_detection_threshold: Minimalny score do zachowania tekstu Returns: Lista list indeksów tekstów do zachowania (po NMS) dla każdego obrazu """ indices_of_texts_to_keep: List[List[int]] = [] for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes): # Filtrowanie: label=1 (tekst) AND score > próg indices: torch.Tensor = torch.where((labels == 1) & ( scores > text_detection_threshold))[0] bboxes = bboxes[indices] scores = scores[indices] labels = labels[indices] if len(indices) == 0: indices_of_texts_to_keep.append([]) continue # Sortowanie tekstów malejąco po score (najlepsze pierwsze) scores, labels, indices, bboxes = zip( *sorted(zip(scores, labels, indices, bboxes), reverse=True)) texts_to_keep: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] # Lista polygonów zaakceptowanych tekstów (do sprawdzania nakładania) texts_to_keep_as_shapely_objects: List[Polygon] = [] for ts, tb, tl, ti in zip(scores, bboxes, labels, indices): # Konwersja bbox na polygon Shapely text_polygon: Polygon = box(tb[0], tb[1], tb[2], tb[3]) should_append: bool = True # Sprawdź nakładanie z już zaakceptowanymi tekstami for t in texts_to_keep_as_shapely_objects: # Jeśli IoU > 0.5, odrzuć (to duplikat) if t.intersection(text_polygon).area / t.union(text_polygon).area > 0.5: should_append = False break if should_append: texts_to_keep.append((ts, tl, tb, ti)) texts_to_keep_as_shapely_objects.append(text_polygon) # Wyciągnięcie indeksów zaakceptowanych tekstów indices_of_texts_to_keep.append( [t[3].item() for t in texts_to_keep]) return indices_of_texts_to_keep def _get_indices_of_tails_to_keep( self, batch_scores: torch.Tensor, batch_labels: torch.Tensor, batch_bboxes: torch.Tensor, text_detection_threshold: float ) -> List[List[int]]: """ Filtruje detekcje ogonów dymków z zastosowaniem NMS (Non-Maximum Suppression). Zachowuje tylko ogony z etykietą 3 i score powyżej progu. Stosuje NMS aby usunąć duplikaty - jeśli nowy ogon ma IoU >0.5 z już zaakceptowanym ogonem, jest odrzucany. Args: batch_scores: Tensor ze scorami [batch, num_queries] batch_labels: Tensor z etykietami [batch, num_queries] batch_bboxes: Tensor z bboxami [batch, num_queries, 4] text_detection_threshold: Minimalny score do zachowania ogona Returns: Lista list indeksów ogonów do zachowania (po NMS) dla każdego obrazu """ indices_of_tails_to_keep: List[List[int]] = [] for scores, labels, bboxes in zip(batch_scores, batch_labels, batch_bboxes): # Filtrowanie: label=3 (ogon dymku) AND score > próg indices: torch.Tensor = torch.where((labels == 3) & ( scores > text_detection_threshold))[0] bboxes = bboxes[indices] scores = scores[indices] labels = labels[indices] if len(indices) == 0: indices_of_tails_to_keep.append([]) continue # Sortowanie ogonów malejąco po score (najlepsze pierwsze) scores, labels, indices, bboxes = zip( *sorted(zip(scores, labels, indices, bboxes), reverse=True)) tails_to_keep: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]] = [] # Lista polygonów zaakceptowanych ogonów (do sprawdzania nakładania) tails_to_keep_as_shapely_objects: List[Polygon] = [] for ts, tb, tl, ti in zip(scores, bboxes, labels, indices): # Konwersja bbox na polygon Shapely tail_polygon: Polygon = box(tb[0], tb[1], tb[2], tb[3]) should_append: bool = True # Sprawdź nakładanie z już zaakceptowanymi ogonami for t in tails_to_keep_as_shapely_objects: # Jeśli IoU > 0.5, odrzuć (to duplikat) if t.intersection(tail_polygon).area / t.union(tail_polygon).area > 0.5: should_append = False break if should_append: tails_to_keep.append((ts, tl, tb, ti)) tails_to_keep_as_shapely_objects.append(tail_polygon) # Wyciągnięcie indeksów zaakceptowanych ogonów indices_of_tails_to_keep.append( [t[3].item() for t in tails_to_keep]) return indices_of_tails_to_keep def _convert_annotations_to_coco_format( self, annotations: Optional[List[Dict[str, Any]]] ) -> Optional[List[Dict[str, Any]]]: """ Konwertuje anotacje z formatu x1y1x2y2 do formatu COCO (xywh). Format COCO używa bbox jako [x, y, width, height] zamiast [x1, y1, x2, y2]. Dodatkowo oblicza pole powierzchni dla każdego bbox. Args: annotations: Lista anotacji w formacie: [{"image_id": int, "bboxes_as_x1y1x2y2": List, "labels": List}] lub None Returns: Lista anotacji w formacie COCO lub None jeśli input był None """ if annotations is None: return None # Weryfikacja poprawności formatu anotacji self._verify_annotations_are_in_correct_format(annotations) coco_annotations: List[Dict[str, Any]] = [] for annotation in annotations: coco_annotation: Dict[str, Any] = { "image_id": annotation["image_id"], "annotations": [], } # Konwersja każdego bbox z x1y1x2y2 na xywh for bbox, label in zip(annotation["bboxes_as_x1y1x2y2"], annotation["labels"]): coco_annotation["annotations"].append({ # [x1,y1,x2,y2] -> [x,y,w,h] "bbox": x1y1x2y2_to_xywh(bbox), "category_id": label, # width * height "area": (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]), }) coco_annotations.append(coco_annotation) return coco_annotations def _verify_annotations_are_in_correct_format(self, annotations: Optional[List[Dict[str, Any]]]) -> None: """ Weryfikuje poprawność formatu anotacji. Sprawdza czy anotacje są w oczekiwanym formacie: - Lista/tupla słowników - Każdy słownik zawiera klucze: "image_id", "bboxes_as_x1y1x2y2", "labels" - Labels: 0=postać, 1=tekst, 2=panel, 3=ogon Args: annotations: Anotacje do weryfikacji lub None Raises: ValueError: Jeśli format anotacji jest nieprawidłowy """ error_msg: str = """ Annotations must be in the following format: [ { "image_id": 0, "bboxes_as_x1y1x2y2": [[0, 0, 10, 10], [10, 10, 20, 20], [20, 20, 30, 30]], "labels": [0, 1, 2], }, ... ] Labels: 0 for characters, 1 for text, 2 for panels, 3 for tails. """ if annotations is None: return # Sprawdzenie czy to lista lub tupla if not isinstance(annotations, List) and not isinstance(annotations, tuple): raise ValueError( f"{error_msg} Expected a List/Tuple, found {type(annotations)}." ) if len(annotations) == 0: return # Sprawdzenie czy elementy to słowniki if not isinstance(annotations[0], dict): raise ValueError( f"{error_msg} Expected a List[Dict], found {type(annotations[0])}." ) # Sprawdzenie wymaganych kluczy w słowniku if "image_id" not in annotations[0]: raise ValueError( f"{error_msg} Dict must contain 'image_id'." ) if "bboxes_as_x1y1x2y2" not in annotations[0]: raise ValueError( f"{error_msg} Dict must contain 'bboxes_as_x1y1x2y2'." ) if "labels" not in annotations[0]: raise ValueError( f"{error_msg} Dict must contain 'labels'." )