Mateusz Mróz
Implement Magiv2Model with detection, OCR, and character association capabilities
cd77b9d
| from transformers import PreTrainedModel, VisionEncoderDecoderModel, ViTMAEModel, ConditionalDetrModel | |
| from transformers.models.conditional_detr.modeling_conditional_detr import ( | |
| ConditionalDetrMLPPredictionHead, | |
| ConditionalDetrModelOutput, | |
| inverse_sigmoid, | |
| ) | |
| from .configuration_magiv2 import Magiv2Config | |
| from .processing_magiv2 import Magiv2Processor | |
| from torch import nn | |
| from typing import Optional, List, Callable, Dict, Any, Tuple | |
| import torch | |
| from einops import rearrange, repeat | |
| from .utils import move_to_device, visualise_single_image_prediction, sort_panels, sort_text_boxes_in_reading_order | |
| from transformers.image_transforms import center_to_corners_format | |
| from .utils import UnionFind, sort_panels, sort_text_boxes_in_reading_order | |
| import pulp | |
| import scipy | |
| import numpy as np | |
| from scipy.optimize import linear_sum_assignment | |
| from numpy.typing import NDArray | |
| class Magiv2Model(PreTrainedModel): | |
| """ | |
| Model Magiv2 - wielomodułowy model wizyjny do analizy komiksów/mang. | |
| Model składa się z trzech głównych komponentów (każdy może być opcjonalnie wyłączony): | |
| 1. Moduł detekcji obiektów - wykrywa panele, postaci, tekst, ogony dymków | |
| 2. Moduł OCR - rozpoznaje tekst w wykrytych obszarach tekstowych | |
| 3. Moduł embedowania - tworzy reprezentacje wektorowe dla wyciętych fragmentów obrazu | |
| Dodatkowo model posiada głowice do: | |
| - Predykcji bounding boxów dla wykrytych obiektów | |
| - Dopasowywania postaci do siebie (character-character matching) | |
| - Dopasowywania tekstu do postaci (text-character matching) | |
| - Dopasowywania tekstu do ogonów dymków (text-tail matching) | |
| - Klasyfikacji typu tekstu (czy to dialog) | |
| Attributes: | |
| config_class: Klasa konfiguracji używana przez ten model | |
| config: Instancja konfiguracji modelu | |
| processor: Procesor do preprocessingu danych wejściowych | |
| ocr_model: Model encoder-decoder do rozpoznawania tekstu (opcjonalny) | |
| crop_embedding_model: Model ViT-MAE do tworzenia embeddingów (opcjonalny) | |
| detection_transformer: Transformer do detekcji obiektów (opcjonalny) | |
| bbox_predictor: Głowica MLP do predykcji bounding boxów | |
| character_character_matching_head: Głowica do dopasowywania postaci | |
| text_character_matching_head: Głowica do dopasowywania tekstu do postaci | |
| text_tail_matching_head: Głowica do dopasowywania tekstu do ogonów | |
| class_labels_classifier: Klasyfikator klas obiektów | |
| is_this_text_a_dialogue: Klasyfikator typu tekstu (dialog vs naracja) | |
| matcher: Hungarian matcher do dopasowywania predykcji do targetów | |
| num_non_obj_tokens: Liczba tokenów niebędących obiektami w outputcie transformera | |
| """ | |
| config_class: type[Magiv2Config] = Magiv2Config | |
| def __init__(self, config: Magiv2Config) -> None: | |
| """ | |
| Inicjalizuje model Magiv2 z podaną konfiguracją. | |
| Args: | |
| config: Obiekt konfiguracji typu Magiv2Config zawierający wszystkie | |
| parametry modelu i informacje o tym, które moduły są aktywne. | |
| Returns: | |
| None | |
| """ | |
| super().__init__(config) | |
| self.config: Magiv2Config = config | |
| self.processor: Magiv2Processor = Magiv2Processor(config) | |
| # Inicjalizacja modelu OCR (opcjonalna, zależna od konfiguracji) | |
| if not config.disable_ocr: | |
| self.ocr_model: VisionEncoderDecoderModel = VisionEncoderDecoderModel( | |
| config.ocr_model_config) | |
| # Inicjalizacja modelu embedowania wycięć (opcjonalna, zależna od konfiguracji) | |
| if not config.disable_crop_embeddings: | |
| self.crop_embedding_model: ViTMAEModel = ViTMAEModel( | |
| config.crop_embedding_model_config) | |
| # Inicjalizacja modułu detekcji obiektów i wszystkich powiązanych głowic | |
| if not config.disable_detections: | |
| # Liczba tokenów w outputcie transformera, które nie reprezentują obiektów | |
| # (tokeny specjalne używane do zadań matching) | |
| self.num_non_obj_tokens: int = 5 | |
| # Główny transformer do detekcji obiektów (panele, postaci, tekst, ogony) | |
| self.detection_transformer: ConditionalDetrModel = ConditionalDetrModel( | |
| config.detection_model_config) | |
| # Głowica MLP do predykcji współrzędnych bounding boxów (4 wartości: cx, cy, w, h) | |
| self.bbox_predictor: ConditionalDetrMLPPredictionHead = ConditionalDetrMLPPredictionHead( | |
| input_dim=config.detection_model_config.d_model, | |
| hidden_dim=config.detection_model_config.d_model, | |
| output_dim=4, num_layers=3 | |
| ) | |
| # Głowica do dopasowywania postaci do siebie (clustering postaci) | |
| # Input: tokeny dwóch postaci + token c2c + opcjonalnie embeddingi wycięć | |
| self.character_character_matching_head: ConditionalDetrMLPPredictionHead = ConditionalDetrMLPPredictionHead( | |
| input_dim=3 * config.detection_model_config.d_model + | |
| (2 * config.crop_embedding_model_config.hidden_size if not config.disable_crop_embeddings else 0), | |
| hidden_dim=config.detection_model_config.d_model, | |
| output_dim=1, num_layers=3 | |
| ) | |
| # Głowica do dopasowywania tekstu do postaci (kto mówi) | |
| # Input: token tekstu + token postaci + token t2c | |
| self.text_character_matching_head: ConditionalDetrMLPPredictionHead = ConditionalDetrMLPPredictionHead( | |
| input_dim=3 * config.detection_model_config.d_model, | |
| hidden_dim=config.detection_model_config.d_model, | |
| output_dim=1, num_layers=3 | |
| ) | |
| # Głowica do dopasowywania tekstu do ogonów dymków | |
| # Input: token tekstu + token ogona | |
| self.text_tail_matching_head: ConditionalDetrMLPPredictionHead = ConditionalDetrMLPPredictionHead( | |
| input_dim=2 * config.detection_model_config.d_model, | |
| hidden_dim=config.detection_model_config.d_model, | |
| output_dim=1, num_layers=3 | |
| ) | |
| # Klasyfikator klas dla wykrytych obiektów | |
| # (0=postać, 1=tekst, 2=panel, 3=ogon, etc.) | |
| self.class_labels_classifier: nn.Linear = nn.Linear( | |
| config.detection_model_config.d_model, config.detection_model_config.num_labels | |
| ) | |
| # Klasyfikator binarny: czy dany tekst to dialog (vs naracja/sound effect) | |
| self.is_this_text_a_dialogue: nn.Linear = nn.Linear( | |
| config.detection_model_config.d_model, 1 | |
| ) | |
| # Hungarian matcher do dopasowywania predykcji do ground truth podczas treningu | |
| self.matcher: ConditionalDetrHungarianMatcher = ConditionalDetrHungarianMatcher( | |
| class_cost=config.detection_model_config.class_cost, | |
| bbox_cost=config.detection_model_config.bbox_cost, | |
| giou_cost=config.detection_model_config.giou_cost | |
| ) | |
| def move_to_device(self, input: Any) -> Any: | |
| """ | |
| Przenosi dane wejściowe na to samo urządzenie co model. | |
| Args: | |
| input: Dane do przeniesienia (tensor, dict, lista, etc.) | |
| Returns: | |
| Dane przeniesione na urządzenie modelu | |
| """ | |
| return move_to_device(input, self.device) | |
| def do_chapter_wide_prediction( | |
| self, | |
| pages_in_order: List[NDArray[np.uint8]], | |
| character_bank: Dict[str, Any], | |
| eta: float = 0.75, | |
| batch_size: int = 8, | |
| use_tqdm: bool = False, | |
| do_ocr: bool = True | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| Wykonuje kompleksową predykcję dla całego rozdziału komiksu/mangi. | |
| Ta metoda przeprowadza pełną analizę wszystkich stron w rozdziale, obejmującą: | |
| 1. Detekcję obiektów (panele, postaci, tekst, ogony dymków) na każdej stronie | |
| 2. Dopasowywanie postaci do siebie w obrębie strony i między stronami | |
| 3. Przypisywanie imion postaci na podstawie banku znanych postaci | |
| 4. Rozpoznawanie tekstu (OCR) w wykrytych obszarach tekstowych | |
| Args: | |
| pages_in_order: Lista obrazów stron w kolejności (każdy obraz jako numpy array) | |
| character_bank: Słownik zawierający bazę znanych postaci: | |
| - "images": lista obrazów referencyjnych postaci | |
| - "names": lista imion odpowiadających obrazom | |
| eta: Parametr kosztu dla opcji "inne postaci" w dopasowywaniu (0-1). | |
| Wyższe wartości zwiększają prawdopodobieństwo przypisania "Other". | |
| batch_size: Rozmiar batcha dla przetwarzania stron (kompromis pamięć/prędkość) | |
| use_tqdm: Czy wyświetlać pasek postępu podczas przetwarzania | |
| do_ocr: Czy wykonać rozpoznawanie tekstu (OCR) na wykrytych obszarach | |
| Returns: | |
| Lista słowników, jeden dla każdej strony, zawierających: | |
| - "panels": lista bounding boxów paneli | |
| - "texts": lista bounding boxów tekstu | |
| - "characters": lista bounding boxów postaci | |
| - "tails": lista bounding boxów ogonów dymków | |
| - "text_character_associations": asocjacje tekst-postać | |
| - "text_tail_associations": asocjacje tekst-ogon | |
| - "character_cluster_labels": etykiety klastrów dla postaci | |
| - "is_essential_text": flagi czy tekst to dialog | |
| - "character_names": przypisane imiona postaci (jeśli dostępne) | |
| - "ocr": rozpoznany tekst (jeśli do_ocr=True) | |
| """ | |
| texts: List[List[List[float]]] = [] | |
| characters: List[List[List[float]]] = [] | |
| character_clusters: List[List[int]] = [] | |
| # Przygotowanie iteratora z opcjonalnym paskiem postępu | |
| if use_tqdm: | |
| from tqdm import tqdm | |
| iterator: Any = tqdm(range(0, len(pages_in_order), batch_size)) | |
| else: | |
| iterator: range = range(0, len(pages_in_order), batch_size) | |
| # Przetwarzanie stron w batchach | |
| per_page_results: List[Dict[str, Any]] = [] | |
| for i in iterator: | |
| pages: List[NDArray[np.uint8]] = pages_in_order[i:i+batch_size] | |
| results: List[Dict[str, Any] | |
| ] = self.predict_detections_and_associations(pages) | |
| per_page_results.extend([result for result in results]) | |
| # Ekstrakcja wyników detekcji dla każdej strony | |
| texts = [result["texts"] for result in per_page_results] | |
| characters = [result["characters"] for result in per_page_results] | |
| character_clusters = [result["character_cluster_labels"] | |
| for result in per_page_results] | |
| # Przypisanie imion postaci na podstawie banku znanych postaci | |
| assigned_character_names: List[str] = self.assign_names_to_characters( | |
| pages_in_order, characters, character_bank, character_clusters, eta=eta) | |
| # Opcjonalne rozpoznawanie tekstu (OCR) | |
| if do_ocr: | |
| ocr: List[List[str]] = self.predict_ocr( | |
| pages_in_order, texts, use_tqdm=use_tqdm) | |
| # Dodawanie przypisanych imion i OCR do wyników dla każdej strony | |
| offset_characters: int = 0 | |
| iteration_over: Any = zip( | |
| per_page_results, ocr) if do_ocr else per_page_results | |
| for iter in iteration_over: | |
| if do_ocr: | |
| result: Dict[str, Any] | |
| ocr_for_page: List[str] | |
| result, ocr_for_page = iter | |
| result["ocr"] = ocr_for_page | |
| else: | |
| result = iter | |
| result["character_names"] = assigned_character_names[offset_characters: | |
| offset_characters + len(result["characters"])] | |
| offset_characters += len(result["characters"]) | |
| return per_page_results | |
| def assign_names_to_characters( | |
| self, | |
| images: List[NDArray[np.uint8]], | |
| character_bboxes: List[List[List[float]]], | |
| character_bank: Dict[str, Any], | |
| character_clusters: List[List[int]], | |
| eta: float = 0.75 | |
| ) -> List[str]: | |
| """ | |
| Przypisuje imiona postaci wykrytym w rozdziale na podstawie banku znanych postaci. | |
| Metoda wykorzystuje: | |
| 1. Embeddingi wizualne wykrytych postaci | |
| 2. Embeddingi postaci z banku referencyjnego | |
| 3. Ograniczenia must-link (postaci z tego samego klastra muszą mieć to samo imię) | |
| 4. Ograniczenia cannot-link (postaci z różnych klastrów nie mogą mieć tego samego imienia) | |
| 5. Problem Optimal Transport z programowaniem liniowym (PuLP) do znalezienia | |
| optymalnego przypisania postaci do imion | |
| Args: | |
| images: Lista obrazów stron z całego rozdziału | |
| character_bboxes: Lista bounding boxów postaci dla każdego obrazu | |
| (list of lists of bboxes) | |
| character_bank: Słownik z bankiem znanych postaci: | |
| - "images": obrazy referencyjne postaci | |
| - "names": imiona odpowiadające obrazom | |
| character_clusters: Etykiety klastrów dla postaci na każdej stronie | |
| (postaci z tym samym ID to prawdopodobnie ta sama osoba) | |
| eta: Parametr kosztu dla opcji "Other" (nieznana postać). | |
| Wyższa wartość = więcej postaci zostanie oznaczonych jako "Other" | |
| Returns: | |
| Lista imion przypisanych do wszystkich wykrytych postaci w kolejności | |
| (lista płaska - imię dla każdej postaci ze wszystkich stron) | |
| """ | |
| # Jeśli bank postaci jest pusty, wszystkie postaci oznaczamy jako "Other" | |
| if len(character_bank["images"]) == 0: | |
| return ["Other" for bboxes_for_image in character_bboxes for bbox in bboxes_for_image] | |
| # Tworzenie embeddingów dla wszystkich postaci w rozdziale | |
| chapter_wide_char_embeddings: List[torch.Tensor] = self.predict_crop_embeddings( | |
| images, character_bboxes) | |
| chapter_wide_char_embeddings_tensor: torch.Tensor = torch.cat( | |
| chapter_wide_char_embeddings, dim=0) | |
| chapter_wide_char_embeddings_normalized: torch.Tensor = torch.nn.functional.normalize( | |
| chapter_wide_char_embeddings_tensor, p=2, dim=1) | |
| chapter_wide_char_embeddings_np: NDArray[np.float32] = chapter_wide_char_embeddings_normalized.cpu( | |
| ).numpy() | |
| # Tworzenie ograniczeń must-link i cannot-link z klastrów postaci | |
| # must-link: postaci z tego samego klastra muszą dostać to samo imię | |
| # cannot-link: postaci z różnych klastrów nie mogą dostać tego samego imienia | |
| must_link: List[Tuple[int, int]] = [] | |
| cannot_link: List[Tuple[int, int]] = [] | |
| offset: int = 0 | |
| for clusters_per_image in character_clusters: | |
| for i in range(len(clusters_per_image)): | |
| for j in range(i+1, len(clusters_per_image)): | |
| if clusters_per_image[i] == clusters_per_image[j]: | |
| must_link.append((offset + i, offset + j)) | |
| else: | |
| cannot_link.append((offset + i, offset + j)) | |
| offset += len(clusters_per_image) | |
| # Tworzenie embeddingów dla postaci z banku referencyjnego | |
| # Używamy pełnego obrazu dla każdej referencyjnej postaci | |
| character_bank_embeddings: List[torch.Tensor] = self.predict_crop_embeddings( | |
| character_bank["images"], [[[0, 0, x.shape[1], x.shape[0]]] for x in character_bank["images"]]) | |
| character_bank_embeddings_tensor: torch.Tensor = torch.cat( | |
| character_bank_embeddings, dim=0) | |
| character_bank_embeddings_normalized: torch.Tensor = torch.nn.functional.normalize( | |
| character_bank_embeddings_tensor, p=2, dim=1) | |
| character_bank_embeddings_np: NDArray[np.float32] = character_bank_embeddings_normalized.cpu( | |
| ).numpy() | |
| # Obliczanie macierzy kosztów (odległości między embeddingami) | |
| costs: NDArray[np.float32] = scipy.spatial.distance.cdist( | |
| chapter_wide_char_embeddings_np, character_bank_embeddings_np) | |
| # Dodanie opcji "Other" (nieznana postać) jako dodatkowa kolumna w macierzy kosztów | |
| none_of_the_above: NDArray[np.float32] = eta * \ | |
| np.ones((costs.shape[0], 1)) | |
| costs = np.concatenate([costs, none_of_the_above], axis=1) | |
| # Konfiguracja problemu optymalizacji (minimalizacja kosztu przypisania) | |
| sense: int = pulp.LpMinimize | |
| num_supply: int | |
| num_demand: int | |
| num_supply, num_demand = costs.shape | |
| problem: pulp.LpProblem = pulp.LpProblem( | |
| "Optimal_Transport_Problem", sense) | |
| # Zmienne binarne: x[(i,j)] = 1 gdy postać i jest przypisana do imienia j | |
| x: Dict[Tuple[int, int], pulp.LpVariable] = pulp.LpVariable.dicts("x", ((i, j) for i in range( | |
| num_supply) for j in range(num_demand)), cat='Binary') | |
| # Funkcja celu: minimalizacja całkowitego kosztu przypisania | |
| problem += pulp.lpSum([costs[i][j] * x[(i, j)] | |
| for i in range(num_supply) for j in range(num_demand)]) | |
| # Ograniczenie: każda wykryta postać musi być przypisana dokładnie do jednego imienia | |
| for i in range(num_supply): | |
| problem += pulp.lpSum([x[(i, j)] for j in range(num_demand)] | |
| ) == 1, f"Supply_{i}_Total_Assignment" | |
| # Ograniczenia cannot-link: postaci z różnych klastrów nie mogą mieć tego samego imienia | |
| for j in range(num_demand-1): # -1 bo ostatnia kolumna to "Other" | |
| for (s1, s2) in cannot_link: | |
| problem += x[(s1, j)] + x[(s2, j) | |
| ] <= 1, f"Exclusion_{s1}_{s2}_Demand_{j}" | |
| # Ograniczenia must-link: postaci z tego samego klastra muszą mieć to samo imię | |
| for j in range(num_demand): | |
| for (s1, s2) in must_link: | |
| problem += x[(s1, j)] - x[(s2, j) | |
| ] == 0, f"Inclusion_{s1}_{s2}_Demand_{j}" | |
| # Rozwiązanie problemu optymalizacji | |
| problem.solve() | |
| # Ekstrakcja wyników (które postaci zostały przypisane do których imion) | |
| assignments: List[Tuple[int, int]] = [] | |
| for v in problem.variables(): | |
| if v.varValue is not None and v.varValue > 0: | |
| index: str | |
| assignment: str | |
| index, assignment = v.name.split( | |
| "(")[1].split(")")[0].split(",") | |
| assignment = assignment[1:] # Usunięcie spacji na początku | |
| assignments.append((int(index), int(assignment))) | |
| # Tworzenie listy etykiet (indeksów imion) dla każdej postaci | |
| labels: NDArray[np.float64] = np.zeros(num_supply) | |
| for i, j in assignments: | |
| labels[i] = j | |
| # Mapowanie indeksów na rzeczywiste imiona (lub "Other") | |
| return [character_bank["names"][int(i)] if i < len(character_bank["names"]) else "Other" for i in labels] | |
| def predict_detections_and_associations( | |
| self, | |
| images: List[NDArray[np.uint8]], | |
| move_to_device_fn: Optional[Callable[[Any], Any]] = None, | |
| character_detection_threshold: float = 0.3, | |
| panel_detection_threshold: float = 0.2, | |
| text_detection_threshold: float = 0.3, | |
| tail_detection_threshold: float = 0.34, | |
| character_character_matching_threshold: float = 0.65, | |
| text_character_matching_threshold: float = 0.35, | |
| text_tail_matching_threshold: float = 0.3, | |
| text_classification_threshold: float = 0.5, | |
| ) -> List[Dict[str, Any]]: | |
| """ | |
| Wykrywa obiekty i ich asocjacje na obrazach stron komiksu/mangi. | |
| Metoda wykonuje następujące kroki: | |
| 1. Detekcję obiektów: panele, postaci, tekst, ogony dymków | |
| 2. Klasyfikację wykrytych obiektów i ich bounding boxów | |
| 3. Filtrowanie detekcji na podstawie progów prawdopodobieństwa | |
| 4. Obliczanie macierzy podobieństwa (affinity matrices): | |
| - text-character: który tekst należy do której postaci | |
| - character-character: które postaci to ta sama osoba | |
| - text-tail: który tekst należy do którego ogona dymku | |
| 5. Przypisywanie asocjacji na podstawie macierzy podobieństwa | |
| 6. Sortowanie paneli w kolejności czytania | |
| 7. Sortowanie tekstów w kolejności czytania w ramach paneli | |
| Args: | |
| images: Lista obrazów do przetworzenia (numpy arrays w formacie HWC) | |
| move_to_device_fn: Funkcja do przenoszenia danych na urządzenie. | |
| Jeśli None, użyje self.move_to_device | |
| character_detection_threshold: Próg prawdopodobieństwa dla detekcji postaci (0-1) | |
| panel_detection_threshold: Próg prawdopodobieństwa dla detekcji paneli (0-1) | |
| text_detection_threshold: Próg prawdopodobieństwa dla detekcji tekstu (0-1) | |
| tail_detection_threshold: Próg prawdopodobieństwa dla detekcji ogonów (0-1) | |
| character_character_matching_threshold: Próg podobieństwa dla dopasowania postaci (0-1) | |
| text_character_matching_threshold: Próg podobieństwa dla dopasowania tekst-postać (0-1) | |
| text_tail_matching_threshold: Próg podobieństwa dla dopasowania tekst-ogon (0-1) | |
| text_classification_threshold: Próg klasyfikacji czy tekst to dialog (0-1) | |
| Returns: | |
| Lista słowników, jeden dla każdego obrazu, zawierających: | |
| - "panels": lista bounding boxów paneli [x1, y1, x2, y2] | |
| - "texts": lista bounding boxów tekstu [x1, y1, x2, y2] | |
| - "characters": lista bounding boxów postaci [x1, y1, x2, y2] | |
| - "tails": lista bounding boxów ogonów dymków [x1, y1, x2, y2] | |
| - "text_character_associations": lista par [idx_tekstu, idx_postaci] | |
| - "text_tail_associations": lista par [idx_tekstu, idx_ogona] | |
| - "character_cluster_labels": etykiety klastrów dla postaci (list of int) | |
| - "is_essential_text": lista flag bool czy dany tekst to dialog | |
| """ | |
| assert not self.config.disable_detections | |
| move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn | |
| # Preprocessing obrazów dla transformera detekcji | |
| inputs_to_detection_transformer: Dict[str, torch.Tensor] = self.processor.preprocess_inputs_for_detection( | |
| images) | |
| inputs_to_detection_transformer = move_to_device_fn( | |
| inputs_to_detection_transformer) | |
| # Przepuszczenie przez transformer detekcji obiektów | |
| detection_transformer_output: ConditionalDetrModelOutput = self._get_detection_transformer_output( | |
| **inputs_to_detection_transformer) | |
| # Pobranie predykcji klas i bounding boxów | |
| predicted_class_scores: torch.Tensor | |
| predicted_bboxes: torch.Tensor | |
| predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes( | |
| detection_transformer_output) | |
| # Przygotowanie rozmiarów oryginalnych obrazów do skalowania bounding boxów | |
| original_image_sizes: torch.Tensor = torch.stack([torch.tensor( | |
| img.shape[:2]) for img in images], dim=0).to(predicted_bboxes.device) | |
| # Konwersja scorów na prawdopodobieństwa i wybranie najlepszych klas | |
| batch_scores: torch.Tensor | |
| batch_labels: torch.Tensor | |
| batch_scores, batch_labels = predicted_class_scores.max(-1) | |
| batch_scores = batch_scores.sigmoid() # Konwersja logitów na prawdopodobieństwa | |
| batch_labels = batch_labels.long() | |
| # Konwersja bounding boxów z formatu center (cx, cy, w, h) na corners (x1, y1, x2, y2) | |
| batch_bboxes: torch.Tensor = center_to_corners_format(predicted_bboxes) | |
| # Skalowanie bounding boxów z powrotem do oryginalnych rozmiarów obrazu | |
| if isinstance(original_image_sizes, List): | |
| img_h: torch.Tensor = torch.Tensor( | |
| [i[0] for i in original_image_sizes]) | |
| img_w: torch.Tensor = torch.Tensor( | |
| [i[1] for i in original_image_sizes]) | |
| else: | |
| img_h: torch.Tensor | |
| img_w: torch.Tensor | |
| img_h, img_w = original_image_sizes.unbind(1) | |
| scale_fct: torch.Tensor = torch.stack( | |
| [img_w, img_h, img_w, img_h], dim=1).to(batch_bboxes.device) | |
| batch_bboxes = batch_bboxes * scale_fct[:, None, :] | |
| # Filtrowanie detekcji na podstawie progów dla każdego typu obiektu | |
| batch_panel_indices: List[torch.Tensor] = self.processor._get_indices_of_panels_to_keep( | |
| batch_scores, batch_labels, batch_bboxes, panel_detection_threshold) | |
| batch_character_indices: List[torch.Tensor] = self.processor._get_indices_of_characters_to_keep( | |
| batch_scores, batch_labels, batch_bboxes, character_detection_threshold) | |
| batch_text_indices: List[torch.Tensor] = self.processor._get_indices_of_texts_to_keep( | |
| batch_scores, batch_labels, batch_bboxes, text_detection_threshold) | |
| batch_tail_indices: List[torch.Tensor] = self.processor._get_indices_of_tails_to_keep( | |
| batch_scores, batch_labels, batch_bboxes, tail_detection_threshold) | |
| # Ekstrakcja tokenów z outputu transformera dla różnych zadań | |
| # Tokeny obiektów - reprezentacje dla każdego wykrytego obiektu | |
| predicted_obj_tokens_for_batch: torch.Tensor = self._get_predicted_obj_tokens( | |
| detection_transformer_output) | |
| # Token t2c - specjalny token do zadania text-to-character matching | |
| predicted_t2c_tokens_for_batch: torch.Tensor = self._get_predicted_t2c_tokens( | |
| detection_transformer_output) | |
| # Token c2c - specjalny token do zadania character-to-character matching | |
| predicted_c2c_tokens_for_batch: torch.Tensor = self._get_predicted_c2c_tokens( | |
| detection_transformer_output) | |
| # Obliczanie macierzy podobieństwa tekst-postać (kto mówi) | |
| text_character_affinity_matrices: List[torch.Tensor] = self._get_text_character_affinity_matrices( | |
| character_obj_tokens_for_batch=[x[i] for x, i in zip( | |
| predicted_obj_tokens_for_batch, batch_character_indices)], | |
| text_obj_tokens_for_this_batch=[x[i] for x, i in zip( | |
| predicted_obj_tokens_for_batch, batch_text_indices)], | |
| t2c_tokens_for_batch=predicted_t2c_tokens_for_batch, | |
| apply_sigmoid=True, | |
| ) | |
| # Przygotowanie bounding boxów postaci do ekstrakcji embeddingów | |
| character_bboxes_in_batch: List[torch.Tensor] = [batch_bboxes[i][j] | |
| for i, j in enumerate(batch_character_indices)] | |
| # Obliczanie macierzy podobieństwa postać-postać (clustering postaci) | |
| character_character_affinity_matrices: List[torch.Tensor] = self._get_character_character_affinity_matrices( | |
| character_obj_tokens_for_batch=[x[i] for x, i in zip( | |
| predicted_obj_tokens_for_batch, batch_character_indices)], | |
| crop_embeddings_for_batch=self.predict_crop_embeddings( | |
| images, character_bboxes_in_batch, move_to_device_fn), | |
| c2c_tokens_for_batch=predicted_c2c_tokens_for_batch, | |
| apply_sigmoid=True, | |
| ) | |
| # Obliczanie macierzy podobieństwa tekst-ogon (który tekst należy do którego dymku) | |
| text_tail_affinity_matrices: List[torch.Tensor] = self._get_text_tail_affinity_matrices( | |
| text_obj_tokens_for_this_batch=[x[i] for x, i in zip( | |
| predicted_obj_tokens_for_batch, batch_text_indices)], | |
| tail_obj_tokens_for_batch=[x[i] for x, i in zip( | |
| predicted_obj_tokens_for_batch, batch_tail_indices)], | |
| apply_sigmoid=True, | |
| ) | |
| # Klasyfikacja czy tekst to dialog (vs naracja/efekt dźwiękowy) | |
| is_this_text_a_dialogue: List[torch.Tensor] = self._get_text_classification( | |
| [x[i] for x, i in zip(predicted_obj_tokens_for_batch, batch_text_indices)]) | |
| # Przygotowanie wyników dla każdego obrazu w batchu | |
| results: List[Dict[str, Any]] = [] | |
| for batch_index in range(len(batch_scores)): | |
| # Pobranie indeksów wykrytych obiektów dla tego obrazu | |
| panel_indices: torch.Tensor = batch_panel_indices[batch_index] | |
| character_indices: torch.Tensor = batch_character_indices[batch_index] | |
| text_indices: torch.Tensor = batch_text_indices[batch_index] | |
| tail_indices: torch.Tensor = batch_tail_indices[batch_index] | |
| # Ekstrakcja bounding boxów dla każdego typu obiektu | |
| character_bboxes: torch.Tensor = batch_bboxes[batch_index][character_indices] | |
| panel_bboxes: torch.Tensor = batch_bboxes[batch_index][panel_indices] | |
| text_bboxes: torch.Tensor = batch_bboxes[batch_index][text_indices] | |
| tail_bboxes: torch.Tensor = batch_bboxes[batch_index][tail_indices] | |
| # Sortowanie paneli w kolejności czytania (góra->dół, prawo->lewo dla mangi) | |
| local_sorted_panel_indices: torch.Tensor = sort_panels( | |
| panel_bboxes) | |
| panel_bboxes = panel_bboxes[local_sorted_panel_indices] | |
| # Sortowanie tekstów w kolejności czytania w ramach paneli | |
| local_sorted_text_indices: torch.Tensor = sort_text_boxes_in_reading_order( | |
| text_bboxes, panel_bboxes) | |
| text_bboxes = text_bboxes[local_sorted_text_indices] | |
| # Pobranie scorów podobieństwa dla tego obrazu (z zachowaniem kolejności sortowania) | |
| character_character_matching_scores: torch.Tensor = character_character_affinity_matrices[ | |
| batch_index] | |
| text_character_matching_scores: torch.Tensor = text_character_affinity_matrices[ | |
| batch_index][local_sorted_text_indices] | |
| text_tail_matching_scores: torch.Tensor = text_tail_affinity_matrices[ | |
| batch_index][local_sorted_text_indices] | |
| # Klasyfikacja tekstów jako dialog/nie-dialog | |
| is_essential_text: torch.Tensor = is_this_text_a_dialogue[batch_index][ | |
| local_sorted_text_indices] > text_classification_threshold | |
| # Clustering postaci na podstawie macierzy podobieństwa (Union-Find algorithm) | |
| # Postaci z tym samym cluster_label to prawdopodobnie ta sama osoba | |
| character_cluster_labels: List[int] = UnionFind.from_adj_matrix( | |
| character_character_matching_scores > character_character_matching_threshold | |
| ).get_labels_for_connected_components() | |
| # Tworzenie asocjacji tekst-postać (przypisywanie mówiącego do każdego tekstu) | |
| if 0 in text_character_matching_scores.shape: | |
| # Brak tekstów lub postaci - pusta lista asocjacji | |
| text_character_associations: torch.Tensor = torch.zeros( | |
| (0, 2), dtype=torch.long) | |
| else: | |
| # Dla każdego tekstu znajdź najbardziej prawdopodobną mówiącą postać | |
| most_likely_speaker_for_each_text: torch.Tensor = torch.argmax( | |
| text_character_matching_scores, dim=1) | |
| text_indices_tensor: torch.Tensor = torch.arange(len(text_bboxes)).type_as( | |
| most_likely_speaker_for_each_text) | |
| text_character_associations: torch.Tensor = torch.stack( | |
| [text_indices_tensor, most_likely_speaker_for_each_text], dim=1) | |
| # Filtrowanie - zachowaj tylko asocjacje powyżej progu pewności | |
| to_keep: torch.Tensor = text_character_matching_scores.max( | |
| dim=1).values > text_character_matching_threshold | |
| text_character_associations = text_character_associations[to_keep] | |
| # Tworzenie asocjacji tekst-ogon (przypisywanie ogona dymku do tekstu) | |
| if 0 in text_tail_matching_scores.shape: | |
| # Brak tekstów lub ogonów - pusta lista asocjacji | |
| text_tail_associations: torch.Tensor = torch.zeros( | |
| (0, 2), dtype=torch.long) | |
| else: | |
| # Dla każdego tekstu znajdź najbardziej prawdopodobny ogon | |
| most_likely_tail_for_each_text: torch.Tensor = torch.argmax( | |
| text_tail_matching_scores, dim=1) | |
| text_indices_tensor: torch.Tensor = torch.arange(len(text_bboxes)).type_as( | |
| most_likely_tail_for_each_text) | |
| text_tail_associations: torch.Tensor = torch.stack( | |
| [text_indices_tensor, most_likely_tail_for_each_text], dim=1) | |
| # Filtrowanie - zachowaj tylko asocjacje powyżej progu pewności | |
| to_keep: torch.Tensor = text_tail_matching_scores.max( | |
| dim=1).values > text_tail_matching_threshold | |
| text_tail_associations = text_tail_associations[to_keep] | |
| # Dodanie wyników dla tego obrazu do listy | |
| results.append({ | |
| "panels": panel_bboxes.tolist(), | |
| "texts": text_bboxes.tolist(), | |
| "characters": character_bboxes.tolist(), | |
| "tails": tail_bboxes.tolist(), | |
| "text_character_associations": text_character_associations.tolist(), | |
| "text_tail_associations": text_tail_associations.tolist(), | |
| "character_cluster_labels": character_cluster_labels, | |
| "is_essential_text": is_essential_text.tolist(), | |
| }) | |
| return results | |
| def get_affinity_matrices_given_annotations( | |
| self, | |
| images: List[NDArray[np.uint8]], | |
| annotations: List[Dict[str, Any]], | |
| move_to_device_fn: Optional[Callable[[Any], Any]] = None, | |
| apply_sigmoid: bool = True | |
| ) -> Dict[str, List[torch.Tensor]]: | |
| """ | |
| Oblicza macierze podobieństwa (affinity matrices) dla anotowanych danych. | |
| Ta metoda jest używana głównie podczas treningu lub ewaluacji, gdy mamy ground truth | |
| annotations. Zamiast używać progów detekcji, używa dopasowania Hungarian Matcher | |
| między predykcjami a ground truth, aby wybrać odpowiednie tokeny dla każdego obiektu. | |
| Args: | |
| images: Lista obrazów do przetworzenia (numpy arrays) | |
| annotations: Lista anotacji dla każdego obrazu, każda zawiera: | |
| - "bboxes_as_x1y1x2y2": lista bounding boxów w formacie [x1,y1,x2,y2] | |
| - "labels": lista etykiet klas dla każdego bbox | |
| (0=postać, 1=tekst, 2=panel, 3=ogon) | |
| move_to_device_fn: Funkcja do przenoszenia danych na urządzenie | |
| apply_sigmoid: Czy aplikować sigmoid do scorów podobieństwa (konwersja logitów->prawdop.) | |
| Returns: | |
| Słownik zawierający: | |
| - "text_character_affinity_matrices": lista macierzy [num_texts, num_characters] | |
| - "character_character_affinity_matrices": lista macierzy [num_chars, num_chars] | |
| - "character_character_affinity_matrices_crop_only": j.w. ale tylko z embeddingów | |
| - "text_tail_affinity_matrices": lista macierzy [num_texts, num_tails] | |
| - "is_this_text_a_dialogue": lista tensorów klasyfikacji tekstu | |
| """ | |
| assert not self.config.disable_detections | |
| move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn | |
| # Ekstrakcja bounding boxów postaci z anotacji (label 0 = postać) | |
| character_bboxes_in_batch: List[List[List[float]]] = [[bbox for bbox, label in zip( | |
| a["bboxes_as_x1y1x2y2"], a["labels"]) if label == 0] for a in annotations] | |
| crop_embeddings_for_batch: List[torch.Tensor] = self.predict_crop_embeddings( | |
| images, character_bboxes_in_batch, move_to_device_fn) | |
| # Preprocessing danych wejściowych dla transformera detekcji (z anotacjami) | |
| inputs_to_detection_transformer: Dict[str, torch.Tensor] = self.processor.preprocess_inputs_for_detection( | |
| images, annotations) | |
| inputs_to_detection_transformer = move_to_device_fn( | |
| inputs_to_detection_transformer) | |
| # Wyciągnięcie przetworzonej listy targetów (usunięcie z inputs) | |
| processed_targets: List[Dict[str, torch.Tensor] | |
| ] = inputs_to_detection_transformer.pop("labels") | |
| # Przepuszczenie przez transformer detekcji | |
| detection_transformer_output: ConditionalDetrModelOutput = self._get_detection_transformer_output( | |
| **inputs_to_detection_transformer) | |
| # Ekstrakcja różnych typów tokenów z outputu transformera | |
| predicted_obj_tokens_for_batch: torch.Tensor = self._get_predicted_obj_tokens( | |
| detection_transformer_output) | |
| predicted_t2c_tokens_for_batch: torch.Tensor = self._get_predicted_t2c_tokens( | |
| detection_transformer_output) | |
| predicted_c2c_tokens_for_batch: torch.Tensor = self._get_predicted_c2c_tokens( | |
| detection_transformer_output) | |
| # Predykcja klas i bounding boxów | |
| predicted_class_scores: torch.Tensor | |
| predicted_bboxes: torch.Tensor | |
| predicted_class_scores, predicted_bboxes = self._get_predicted_bboxes_and_classes( | |
| detection_transformer_output) | |
| # Przygotowanie danych do Hungarian matchera | |
| matching_dict: Dict[str, torch.Tensor] = { | |
| "logits": predicted_class_scores, | |
| "pred_boxes": predicted_bboxes, | |
| } | |
| # Wykonanie dopasowania węgierskiego między predykcjami a ground truth | |
| indices: List[Tuple[torch.Tensor, torch.Tensor] | |
| ] = self.matcher(matching_dict, processed_targets) | |
| # Listy do przechowania dopasowanych tokenów dla każdego typu obiektu | |
| matched_char_obj_tokens_for_batch: List[torch.Tensor] = [] | |
| matched_text_obj_tokens_for_batch: List[torch.Tensor] = [] | |
| matched_tail_obj_tokens_for_batch: List[torch.Tensor] = [] | |
| t2c_tokens_for_batch: List[torch.Tensor] = [] | |
| c2c_tokens_for_batch: List[torch.Tensor] = [] | |
| # Dla każdego obrazu w batchu, ekstrakcja dopasowanych tokenów | |
| for j, (pred_idx, tgt_idx) in enumerate(indices): | |
| # Mapowanie: indeks w targetach -> indeks w predykcjach | |
| target_idx_to_pred_idx: Dict[int, int] = {tgt.item(): pred.item() | |
| for pred, tgt in zip(pred_idx, tgt_idx)} | |
| targets_for_this_image: Dict[str, | |
| torch.Tensor] = processed_targets[j] | |
| # Znajdź indeksy obiektów każdego typu w anotacjach | |
| # label 1 = tekst | |
| indices_of_text_boxes_in_annotation: List[int] = [i for i, label in enumerate( | |
| targets_for_this_image["class_labels"]) if label == 1] | |
| # label 0 = postać | |
| indices_of_char_boxes_in_annotation: List[int] = [i for i, label in enumerate( | |
| targets_for_this_image["class_labels"]) if label == 0] | |
| # label 3 = ogon dymku | |
| indices_of_tail_boxes_in_annotation: List[int] = [i for i, label in enumerate( | |
| targets_for_this_image["class_labels"]) if label == 3] | |
| # Zmapowanie indeksów targetów na indeksy predykcji | |
| predicted_text_indices: List[int] = [target_idx_to_pred_idx[i] | |
| for i in indices_of_text_boxes_in_annotation] | |
| predicted_char_indices: List[int] = [target_idx_to_pred_idx[i] | |
| for i in indices_of_char_boxes_in_annotation] | |
| predicted_tail_indices: List[int] = [target_idx_to_pred_idx[i] | |
| for i in indices_of_tail_boxes_in_annotation] | |
| # Wyciągnięcie tokenów odpowiadających dopasowanym obiektom | |
| matched_char_obj_tokens_for_batch.append( | |
| predicted_obj_tokens_for_batch[j][predicted_char_indices]) | |
| matched_text_obj_tokens_for_batch.append( | |
| predicted_obj_tokens_for_batch[j][predicted_text_indices]) | |
| matched_tail_obj_tokens_for_batch.append( | |
| predicted_obj_tokens_for_batch[j][predicted_tail_indices]) | |
| # Dodanie tokenów specjalnych dla tego obrazu | |
| t2c_tokens_for_batch.append(predicted_t2c_tokens_for_batch[j]) | |
| c2c_tokens_for_batch.append(predicted_c2c_tokens_for_batch[j]) | |
| # Obliczanie macierzy podobieństwa tekst-postać (speaker assignment) | |
| text_character_affinity_matrices: List[torch.Tensor] = self._get_text_character_affinity_matrices( | |
| character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch, | |
| text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch, | |
| t2c_tokens_for_batch=t2c_tokens_for_batch, | |
| apply_sigmoid=apply_sigmoid, | |
| ) | |
| # Obliczanie macierzy podobieństwa postać-postać (character clustering) | |
| # Używa zarówno tokenów z transformera jak i embeddingów z ViT-MAE | |
| character_character_affinity_matrices: List[torch.Tensor] = self._get_character_character_affinity_matrices( | |
| character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch, | |
| crop_embeddings_for_batch=crop_embeddings_for_batch, | |
| c2c_tokens_for_batch=c2c_tokens_for_batch, | |
| apply_sigmoid=apply_sigmoid, | |
| ) | |
| # Obliczanie macierzy podobieństwa postać-postać TYLKO na podstawie embeddingów | |
| # (bez tokenów z transformera, crop_only=True) | |
| character_character_affinity_matrices_crop_only: List[torch.Tensor] = self._get_character_character_affinity_matrices( | |
| character_obj_tokens_for_batch=matched_char_obj_tokens_for_batch, | |
| crop_embeddings_for_batch=crop_embeddings_for_batch, | |
| c2c_tokens_for_batch=c2c_tokens_for_batch, | |
| crop_only=True, | |
| apply_sigmoid=apply_sigmoid, | |
| ) | |
| # Obliczanie macierzy podobieństwa tekst-ogon (text-to-tail matching) | |
| text_tail_affinity_matrices: List[torch.Tensor] = self._get_text_tail_affinity_matrices( | |
| text_obj_tokens_for_this_batch=matched_text_obj_tokens_for_batch, | |
| tail_obj_tokens_for_batch=matched_tail_obj_tokens_for_batch, | |
| apply_sigmoid=apply_sigmoid, | |
| ) | |
| # Klasyfikacja czy tekst to dialog (vs naracja/efekt dźwiękowy) | |
| is_this_text_a_dialogue: List[torch.Tensor] = self._get_text_classification( | |
| matched_text_obj_tokens_for_batch, apply_sigmoid=apply_sigmoid) | |
| return { | |
| "text_character_affinity_matrices": text_character_affinity_matrices, | |
| "character_character_affinity_matrices": character_character_affinity_matrices, | |
| "character_character_affinity_matrices_crop_only": character_character_affinity_matrices_crop_only, | |
| "text_tail_affinity_matrices": text_tail_affinity_matrices, | |
| "is_this_text_a_dialogue": is_this_text_a_dialogue, | |
| } | |
| def predict_crop_embeddings( | |
| self, | |
| images: List[NDArray[np.uint8]], | |
| crop_bboxes: List[List[List[float]]], | |
| move_to_device_fn: Optional[Callable[[Any], Any]] = None, | |
| mask_ratio: float = 0.0, | |
| batch_size: int = 256 | |
| ) -> List[torch.Tensor]: | |
| """ | |
| Tworzy embeddingi wektorowe dla wyciętych fragmentów obrazów (crops). | |
| Metoda wykorzystuje model ViT-MAE (Vision Transformer - Masked Autoencoder) | |
| do tworzenia reprezentacji wektorowych dla regionów obrazu określonych przez | |
| bounding boxy. Embeddingi są używane głównie do dopasowywania postaci | |
| (character-character matching). | |
| Args: | |
| images: Lista obrazów źródłowych (numpy arrays w formacie HWC) | |
| crop_bboxes: Lista list bounding boxów dla każdego obrazu. | |
| Format bbox: [x1, y1, x2, y2] (corners format) | |
| move_to_device_fn: Funkcja do przenoszenia danych na urządzenie | |
| mask_ratio: Współczynnik maskowania dla ViT-MAE (0.0 = bez maskowania, | |
| wyższe wartości = więcej zamaskowanych patchów). Domyślnie 0.0 | |
| dla inferencji (chcemy pełne embeddingi bez rekonstrukcji) | |
| batch_size: Maksymalna liczba crops przetwarzanych jednocześnie | |
| (kontrola zużycia pamięci GPU) | |
| Returns: | |
| Lista tensorów embeddingów, jeden tensor dla każdego obrazu. | |
| Każdy tensor ma kształt [num_crops, hidden_size]. | |
| Jeśli moduł embedowania jest wyłączony, zwraca listę pustych tensorów. | |
| """ | |
| if self.config.disable_crop_embeddings: | |
| return None | |
| assert isinstance( | |
| crop_bboxes, List), "please provide a list of bboxes for each image to get embeddings for" | |
| move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn | |
| # Tymczasowa zmiana mask_ratio z wartości domyślnej na określoną | |
| # (zapisujemy starą wartość do przywrócenia później) | |
| old_mask_ratio: float = self.crop_embedding_model.embeddings.config.mask_ratio | |
| self.crop_embedding_model.embeddings.config.mask_ratio = mask_ratio | |
| # Wycinanie fragmentów obrazów zgodnie z bounding boxami | |
| crops_per_image: List[NDArray[np.uint8]] = [] | |
| num_crops_per_batch: List[int] = [ | |
| len(bboxes) for bboxes in crop_bboxes] | |
| for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch): | |
| crops: List[NDArray[np.uint8] | |
| ] = self.processor.crop_image(image, bboxes) | |
| assert len(crops) == num_crops | |
| crops_per_image.extend(crops) | |
| # Jeśli brak crops, zwróć puste tensory odpowiedniego kształtu | |
| if len(crops_per_image) == 0: | |
| return [move_to_device_fn(torch.zeros(0, self.config.crop_embedding_model_config.hidden_size)) for _ in crop_bboxes] | |
| # Preprocessing crops (normalizacja, resize, konwersja na tensor) | |
| crops_per_image_tensor: torch.Tensor = self.processor.preprocess_inputs_for_crop_embeddings( | |
| crops_per_image) | |
| crops_per_image_tensor = move_to_device_fn(crops_per_image_tensor) | |
| # Przetwarzanie crops w batchach aby uniknąć OOM (Out Of Memory) | |
| embeddings: List[torch.Tensor] = [] | |
| for i in range(0, len(crops_per_image_tensor), batch_size): | |
| crops: torch.Tensor = crops_per_image_tensor[i:i+batch_size] | |
| # Pobieramy token [CLS] (indeks 0) jako reprezentację całego cropu | |
| embeddings_per_batch: torch.Tensor = self.crop_embedding_model( | |
| crops).last_hidden_state[:, 0] | |
| embeddings.append(embeddings_per_batch) | |
| embeddings_concat: torch.Tensor = torch.cat(embeddings, dim=0) | |
| # Rozdzielenie embeddingów z powrotem na grupy odpowiadające obrazom | |
| crop_embeddings_for_batch: List[torch.Tensor] = [] | |
| for num_crops in num_crops_per_batch: | |
| crop_embeddings_for_batch.append(embeddings_concat[:num_crops]) | |
| embeddings_concat = embeddings_concat[num_crops:] | |
| # Przywrócenie oryginalnego mask_ratio | |
| self.crop_embedding_model.embeddings.config.mask_ratio = old_mask_ratio | |
| return crop_embeddings_for_batch | |
| def predict_ocr( | |
| self, | |
| images: List[NDArray[np.uint8]], | |
| crop_bboxes: List[List[List[float]]], | |
| move_to_device_fn: Optional[Callable[[Any], Any]] = None, | |
| use_tqdm: bool = False, | |
| batch_size: int = 32, | |
| max_new_tokens: int = 64 | |
| ) -> List[List[str]]: | |
| """ | |
| Rozpoznaje tekst (OCR) w określonych regionach obrazów. | |
| Metoda wykorzystuje model Vision-Encoder-Decoder (VED) do rozpoznawania | |
| tekstu w wyciętych fragmentach obrazu. Encoder przetwarza obraz tekstu, | |
| a decoder generuje sekwencję tokenów tekstowych autoregresywnie. | |
| Args: | |
| images: Lista obrazów źródłowych (numpy arrays) | |
| crop_bboxes: Lista list bounding boxów dla każdego obrazu, | |
| określających regiony z tekstem do rozpoznania. | |
| Format: [x1, y1, x2, y2] | |
| move_to_device_fn: Funkcja do przenoszenia danych na urządzenie | |
| use_tqdm: Czy wyświetlać pasek postępu podczas przetwarzania | |
| batch_size: Liczba crops przetwarzanych jednocześnie (kontrola pamięci) | |
| max_new_tokens: Maksymalna liczba tokenów do wygenerowania dla każdego | |
| fragmentu tekstu (kontrola długości wyjścia) | |
| Returns: | |
| Lista list stringów, jedna lista dla każdego obrazu. | |
| Każdy string to rozpoznany tekst z odpowiadającego bbox. | |
| Jeśli moduł OCR jest wyłączony, podnosi AssertionError. | |
| """ | |
| assert not self.config.disable_ocr | |
| move_to_device_fn = self.move_to_device if move_to_device_fn is None else move_to_device_fn | |
| # Wycinanie fragmentów obrazów z tekstem | |
| crops_per_image: List[NDArray[np.uint8]] = [] | |
| num_crops_per_batch: List[int] = [ | |
| len(bboxes) for bboxes in crop_bboxes] | |
| for image, bboxes, num_crops in zip(images, crop_bboxes, num_crops_per_batch): | |
| crops: List[NDArray[np.uint8] | |
| ] = self.processor.crop_image(image, bboxes) | |
| assert len(crops) == num_crops | |
| crops_per_image.extend(crops) | |
| # Jeśli brak crops, zwróć puste listy | |
| if len(crops_per_image) == 0: | |
| return [[] for _ in crop_bboxes] | |
| # Preprocessing crops dla OCR (normalizacja, resize, konwersja na tensor) | |
| crops_per_image_tensor: torch.Tensor = self.processor.preprocess_inputs_for_ocr( | |
| crops_per_image) | |
| crops_per_image_tensor = move_to_device_fn(crops_per_image_tensor) | |
| # Przetwarzanie crops w batchach aby uniknąć OOM | |
| all_generated_texts: List[str] = [] | |
| if use_tqdm: | |
| from tqdm import tqdm | |
| pbar: Any = tqdm(range(0, len(crops_per_image_tensor), batch_size)) | |
| else: | |
| pbar: range = range(0, len(crops_per_image_tensor), batch_size) | |
| for i in pbar: | |
| crops: torch.Tensor = crops_per_image_tensor[i:i+batch_size] | |
| # Generowanie tekstu autoregresywnie (beam search / greedy decoding) | |
| generated_ids: torch.Tensor = self.ocr_model.generate( | |
| crops, max_new_tokens=max_new_tokens) | |
| # Dekodowanie tokenów ID na stringi tekstowe | |
| generated_texts: List[str] = self.processor.postprocess_ocr_tokens( | |
| generated_ids) | |
| all_generated_texts.extend(generated_texts) | |
| # Rozdzielenie wyników OCR z powrotem na grupy odpowiadające obrazom | |
| texts_for_images: List[List[str]] = [] | |
| for num_crops in num_crops_per_batch: | |
| # Usunięcie znaków nowej linii z rozpoznanego tekstu | |
| texts_for_images.append([x.replace("\n", "") | |
| for x in all_generated_texts[:num_crops]]) | |
| all_generated_texts = all_generated_texts[num_crops:] | |
| return texts_for_images | |
| def visualise_single_image_prediction( | |
| self, | |
| image_as_np_array: NDArray[np.uint8], | |
| predictions: Dict[str, Any], | |
| filename: Optional[str] = None | |
| ) -> Any: | |
| """ | |
| Wizualizuje wyniki predykcji na obrazie. | |
| Rysuje bounding boxy dla wykrytych obiektów (panele, postaci, tekst, ogony) | |
| oraz asocjacje między nimi (linie łączące tekst z postacią, tekst z ogonem). | |
| Args: | |
| image_as_np_array: Obraz do wizualizacji (numpy array w formacie HWC) | |
| predictions: Słownik z wynikami predykcji zawierający klucze: | |
| - "panels", "texts", "characters", "tails": bounding boxy | |
| - "text_character_associations": asocjacje tekst-postać | |
| - "text_tail_associations": asocjacje tekst-ogon | |
| filename: Opcjonalna ścieżka do zapisu wizualizacji (jeśli None, tylko wyświetli) | |
| Returns: | |
| Obiekt wizualizacji (zależny od implementacji funkcji pomocniczej) | |
| """ | |
| return visualise_single_image_prediction(image_as_np_array, predictions, filename) | |
| def _get_detection_transformer_output( | |
| self, | |
| pixel_values: torch.FloatTensor, | |
| pixel_mask: Optional[torch.LongTensor] = None | |
| ) -> ConditionalDetrModelOutput: | |
| """ | |
| Przepuszcza obrazy przez transformer detekcji obiektów. | |
| Args: | |
| pixel_values: Tensor z wartościami pikseli obrazów [batch, channels, height, width] | |
| pixel_mask: Opcjonalna maska określająca które piksele są padding | |
| (1 = valid pixel, 0 = padding) | |
| Returns: | |
| Output transformera zawierający: | |
| - last_hidden_state: tokeny dla obiektów i tokenów specjalnych | |
| - reference_points: punkty referencyjne dla predykcji bounding boxów | |
| - intermediate_hidden_states: stany z warstw pośrednich (opcjonalnie) | |
| Raises: | |
| ValueError: Jeśli moduł detekcji jest wyłączony w konfiguracji | |
| """ | |
| if self.config.disable_detections: | |
| raise ValueError( | |
| "Detection model is disabled. Set disable_detections=False in the config.") | |
| return self.detection_transformer( | |
| pixel_values=pixel_values, | |
| pixel_mask=pixel_mask, | |
| return_dict=True | |
| ) | |
| def _get_predicted_obj_tokens( | |
| self, | |
| detection_transformer_output: ConditionalDetrModelOutput | |
| ) -> torch.Tensor: | |
| """ | |
| Ekstraktuje tokeny reprezentujące wykryte obiekty z outputu transformera. | |
| Tokeny obiektów to reprezentacje wektorowe dla każdego wykrytego obiektu | |
| (panele, postaci, tekst, ogony). Ostatnie num_non_obj_tokens tokenów | |
| to tokeny specjalne używane do zadań matching (c2c, t2c, etc.). | |
| Args: | |
| detection_transformer_output: Output z transformera detekcji | |
| Returns: | |
| Tensor tokenów obiektów o kształcie [batch, num_objects, hidden_dim] | |
| """ | |
| return detection_transformer_output.last_hidden_state[:, :-self.num_non_obj_tokens] | |
| def _get_predicted_c2c_tokens( | |
| self, | |
| detection_transformer_output: ConditionalDetrModelOutput | |
| ) -> torch.Tensor: | |
| """ | |
| Ekstraktuje token c2c (character-to-character) z outputu transformera. | |
| Token c2c to specjalny token używany do zadania dopasowywania postaci | |
| do siebie (character clustering). Jest to token na pozycji -num_non_obj_tokens. | |
| Args: | |
| detection_transformer_output: Output z transformera detekcji | |
| Returns: | |
| Tensor tokenu c2c o kształcie [batch, hidden_dim] | |
| """ | |
| return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens] | |
| def _get_predicted_t2c_tokens( | |
| self, | |
| detection_transformer_output: ConditionalDetrModelOutput | |
| ) -> torch.Tensor: | |
| """ | |
| Ekstraktuje token t2c (text-to-character) z outputu transformera. | |
| Token t2c to specjalny token używany do zadania dopasowywania tekstu | |
| do postaci (speaker assignment). Jest to token na pozycji -num_non_obj_tokens+1. | |
| Args: | |
| detection_transformer_output: Output z transformera detekcji | |
| Returns: | |
| Tensor tokenu t2c o kształcie [batch, hidden_dim] | |
| """ | |
| return detection_transformer_output.last_hidden_state[:, -self.num_non_obj_tokens+1] | |
| def _get_predicted_bboxes_and_classes( | |
| self, | |
| detection_transformer_output: ConditionalDetrModelOutput, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Predykcja klas i bounding boxów dla wykrytych obiektów. | |
| Metoda wykorzystuje tokeny obiektów do: | |
| 1. Klasyfikacji każdego obiektu (panel, postać, tekst, ogon) | |
| 2. Predykcji bounding boxa w formacie center (cx, cy, w, h) | |
| Bounding boxy są predykcyjne względem punktów referencyjnych (reference points) | |
| z deformable attention, co poprawia dokładność lokalizacji. | |
| Args: | |
| detection_transformer_output: Output z transformera detekcji | |
| Returns: | |
| Krotka (predicted_class_scores, predicted_boxes): | |
| - predicted_class_scores: logity klas [batch, num_objects, num_classes] | |
| - predicted_boxes: boxy w formacie center [batch, num_objects, 4] | |
| Raises: | |
| ValueError: Jeśli moduł detekcji jest wyłączony | |
| """ | |
| if self.config.disable_detections: | |
| raise ValueError( | |
| "Detection model is disabled. Set disable_detections=False in the config.") | |
| # Pobranie tokenów obiektów (bez tokenów specjalnych) | |
| obj: torch.Tensor = self._get_predicted_obj_tokens( | |
| detection_transformer_output) | |
| # Klasyfikacja obiektów (0=postać, 1=tekst, 2=panel, 3=ogon) | |
| predicted_class_scores: torch.Tensor = self.class_labels_classifier( | |
| obj) | |
| # Pobranie punktów referencyjnych (bez punktów dla tokenów specjalnych) | |
| reference: torch.Tensor = detection_transformer_output.reference_points[:- | |
| self.num_non_obj_tokens] | |
| # Konwersja z przestrzeni sigmoid na logity dla dodawania offsetów | |
| reference_before_sigmoid: torch.Tensor = inverse_sigmoid( | |
| reference).transpose(0, 1) | |
| # Predykcja offsetów bounding boxów względem punktów referencyjnych | |
| predicted_boxes: torch.Tensor = self.bbox_predictor(obj) | |
| # Dodanie offsetów do punktów referencyjnych (tylko dla współrzędnych środka cx, cy) | |
| predicted_boxes[..., :2] += reference_before_sigmoid | |
| # Konwersja z logitów na wartości [0, 1] przez sigmoid | |
| predicted_boxes = predicted_boxes.sigmoid() | |
| return predicted_class_scores, predicted_boxes | |
| def _get_text_classification( | |
| self, | |
| text_obj_tokens_for_batch: List[torch.FloatTensor], | |
| apply_sigmoid: bool = False, | |
| ) -> List[torch.Tensor]: | |
| """ | |
| Klasyfikuje teksty jako dialog lub nie-dialog (naracja, efekty dźwiękowe). | |
| Używa klasyfikatora binarnego na tokenach tekstowych do określenia | |
| czy dany tekst to dialog postaci czy inny typ tekstu (naracja, onomatopeje). | |
| Args: | |
| text_obj_tokens_for_batch: Lista tensorów tokenów tekstowych, | |
| jeden tensor dla każdego obrazu w batchu | |
| apply_sigmoid: Czy aplikować sigmoid do outputu (konwersja logitów na prawdop.) | |
| Returns: | |
| Lista tensorów klasyfikacji, jeden dla każdego obrazu. | |
| Każdy tensor ma kształt [num_texts] z wartościami logitów lub prawdopodobieństw. | |
| """ | |
| assert not self.config.disable_detections | |
| is_this_text_a_dialogue: List[torch.Tensor] = [] | |
| for text_obj_tokens in text_obj_tokens_for_batch: | |
| # Jeśli brak tekstów, zwróć pusty tensor | |
| if text_obj_tokens.shape[0] == 0: | |
| is_this_text_a_dialogue.append( | |
| torch.tensor([], dtype=torch.bool)) | |
| continue | |
| # Klasyfikacja każdego tekstu (output: [num_texts, 1] -> squeeze -> [num_texts]) | |
| classification: torch.Tensor = self.is_this_text_a_dialogue( | |
| text_obj_tokens).squeeze(-1) | |
| if apply_sigmoid: | |
| classification = classification.sigmoid() | |
| is_this_text_a_dialogue.append(classification) | |
| return is_this_text_a_dialogue | |
| def _get_character_character_affinity_matrices( | |
| self, | |
| character_obj_tokens_for_batch: Optional[List[torch.FloatTensor]] = None, | |
| crop_embeddings_for_batch: Optional[List[torch.FloatTensor]] = None, | |
| c2c_tokens_for_batch: Optional[List[torch.FloatTensor]] = None, | |
| crop_only: bool = False, | |
| apply_sigmoid: bool = True, | |
| ) -> List[torch.Tensor]: | |
| """ | |
| Oblicza macierze podobieństwa między parami postaci (character-character affinity). | |
| Macierze określają prawdopodobieństwo, że dwie postaci to ta sama osoba. | |
| Używane do clusteringu postaci w obrębie strony i między stronami. | |
| Metoda działa w dwóch trybach: | |
| 1. crop_only=True: podobieństwo oparte tylko na embeddingach wizualnych (cosine similarity) | |
| 2. crop_only=False: podobieństwo oparte na tokenach + embeddingach + tokenie c2c | |
| Args: | |
| character_obj_tokens_for_batch: Lista tokenów postaci dla każdego obrazu | |
| crop_embeddings_for_batch: Lista embeddingów wizualnych postaci dla każdego obrazu | |
| c2c_tokens_for_batch: Lista tokenów c2c dla każdego obrazu | |
| crop_only: Czy użyć tylko embeddingów wizualnych (bez tokenów i c2c) | |
| apply_sigmoid: Czy aplikować sigmoid do scorów (konwersja logitów na prawdop.) | |
| Returns: | |
| Lista macierzy podobieństwa, jedna dla każdego obrazu. | |
| Każda macierz ma kształt [num_characters, num_characters] symetryczna. | |
| Wartości w [0,1] jeśli apply_sigmoid=True, logity w przeciwnym razie. | |
| """ | |
| assert self.config.disable_detections or ( | |
| character_obj_tokens_for_batch is not None and c2c_tokens_for_batch is not None) | |
| assert self.config.disable_crop_embeddings or crop_embeddings_for_batch is not None | |
| assert not self.config.disable_detections or not self.config.disable_crop_embeddings | |
| # Tryb crop_only: podobieństwo oparte tylko na cosine similarity embeddingów | |
| if crop_only: | |
| affinity_matrices: List[torch.Tensor] = [] | |
| for crop_embeddings in crop_embeddings_for_batch: | |
| # Normalizacja embeddingów do jednostkowej długości | |
| crop_embeddings_normalized: torch.Tensor = crop_embeddings / \ | |
| crop_embeddings.norm(dim=-1, keepdim=True) | |
| # Cosine similarity: iloczyn skalarny znormalizowanych wektorów | |
| affinity_matrix: torch.Tensor = crop_embeddings_normalized @ crop_embeddings_normalized.T | |
| affinity_matrices.append(affinity_matrix) | |
| return affinity_matrices | |
| # Tryb pełny: podobieństwo z tokenów + embeddingów + tokenu c2c | |
| affinity_matrices: List[torch.Tensor] = [] | |
| for batch_index, (character_obj_tokens, c2c) in enumerate(zip(character_obj_tokens_for_batch, c2c_tokens_for_batch)): | |
| # Jeśli brak postaci, zwróć pustą macierz | |
| if character_obj_tokens.shape[0] == 0: | |
| affinity_matrices.append(torch.zeros( | |
| 0, 0).type_as(character_obj_tokens)) | |
| continue | |
| # Konkatenacja tokenów z embeddingami (jeśli dostępne) | |
| if not self.config.disable_crop_embeddings: | |
| crop_embeddings: torch.Tensor = crop_embeddings_for_batch[batch_index] | |
| assert character_obj_tokens.shape[0] == crop_embeddings.shape[0] | |
| character_obj_tokens = torch.cat( | |
| [character_obj_tokens, crop_embeddings], dim=-1) | |
| # Tworzenie par (i, j) wszystkich postaci dla obliczenia podobieństwa | |
| # char_i: każda postać i powtórzona num_characters razy | |
| char_i: torch.Tensor = repeat(character_obj_tokens, "i d -> i repeat d", | |
| repeat=character_obj_tokens.shape[0]) | |
| # char_j: wszystkie postaci j powtórzone dla każdej postaci i | |
| char_j: torch.Tensor = repeat(character_obj_tokens, "j d -> repeat j d", | |
| repeat=character_obj_tokens.shape[0]) | |
| # Konkatenacja par: [char_i, char_j] -> [num_pairs, 2*hidden_dim] | |
| char_ij: torch.Tensor = rearrange( | |
| [char_i, char_j], "two i j d -> (i j) (two d)") | |
| # Dodanie tokenu c2c do każdej pary (kontekst globalny dla matching) | |
| c2c_repeated: torch.Tensor = repeat( | |
| c2c, "d -> repeat d", repeat=char_ij.shape[0]) | |
| char_ij_c2c: torch.Tensor = torch.cat( | |
| [char_ij, c2c_repeated], dim=-1) | |
| # Predykcja scorów podobieństwa przez MLP head | |
| character_character_affinities: torch.Tensor = self.character_character_matching_head( | |
| char_ij_c2c) | |
| # Reshape z [num_pairs, 1] na macierz [num_characters, num_characters] | |
| character_character_affinities = rearrange( | |
| character_character_affinities, "(i j) 1 -> i j", i=char_i.shape[0]) | |
| # Wymuszenie symetryczności macierzy (score(i,j) = score(j,i)) | |
| character_character_affinities = ( | |
| character_character_affinities + character_character_affinities.T) / 2 | |
| if apply_sigmoid: | |
| character_character_affinities = character_character_affinities.sigmoid() | |
| affinity_matrices.append(character_character_affinities) | |
| return affinity_matrices | |
| def _get_text_character_affinity_matrices( | |
| self, | |
| character_obj_tokens_for_batch: Optional[List[torch.FloatTensor]] = None, | |
| text_obj_tokens_for_this_batch: Optional[List[torch.FloatTensor]] = None, | |
| t2c_tokens_for_batch: Optional[List[torch.FloatTensor]] = None, | |
| apply_sigmoid: bool = True, | |
| ) -> List[torch.Tensor]: | |
| """ | |
| Oblicza macierze podobieństwa między tekstami a postaciami (speaker assignment). | |
| Dla każdej pary (tekst, postać) oblicza prawdopodobieństwo, że dany tekst | |
| jest wypowiadany przez daną postać. Używane do przypisywania dialogów do mówiących. | |
| Args: | |
| character_obj_tokens_for_batch: Lista tokenów postaci dla każdego obrazu | |
| text_obj_tokens_for_this_batch: Lista tokenów tekstów dla każdego obrazu | |
| t2c_tokens_for_batch: Lista tokenów t2c dla każdego obrazu | |
| apply_sigmoid: Czy aplikować sigmoid do scorów (konwersja logitów na prawdop.) | |
| Returns: | |
| Lista macierzy podobieństwa, jedna dla każdego obrazu. | |
| Każda macierz ma kształt [num_texts, num_characters]. | |
| Wartość macierzy[i][j] = prawdopodobieństwo, że tekst i należy do postaci j. | |
| """ | |
| assert not self.config.disable_detections | |
| assert character_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None and t2c_tokens_for_batch is not None | |
| affinity_matrices: List[torch.Tensor] = [] | |
| for character_obj_tokens, text_obj_tokens, t2c in zip(character_obj_tokens_for_batch, text_obj_tokens_for_this_batch, t2c_tokens_for_batch): | |
| # Jeśli brak tekstów lub postaci, zwróć pustą macierz | |
| if character_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0: | |
| affinity_matrices.append(torch.zeros( | |
| text_obj_tokens.shape[0], character_obj_tokens.shape[0]).type_as(character_obj_tokens)) | |
| continue | |
| # Tworzenie par (text_i, character_j) dla wszystkich kombinacji | |
| # text_i: każdy tekst i powtórzony num_characters razy | |
| text_i: torch.Tensor = repeat(text_obj_tokens, "i d -> i repeat d", | |
| repeat=character_obj_tokens.shape[0]) | |
| # char_j: wszystkie postaci j powtórzone dla każdego tekstu i | |
| char_j: torch.Tensor = repeat(character_obj_tokens, "j d -> repeat j d", | |
| repeat=text_obj_tokens.shape[0]) | |
| # Konkatenacja par: [text_i, char_j] -> [num_pairs, 2*hidden_dim] | |
| text_char: torch.Tensor = rearrange( | |
| [text_i, char_j], "two i j d -> (i j) (two d)") | |
| # Dodanie tokenu t2c do każdej pary (kontekst globalny dla text-character matching) | |
| t2c_repeated: torch.Tensor = repeat( | |
| t2c, "d -> repeat d", repeat=text_char.shape[0]) | |
| text_char_t2c: torch.Tensor = torch.cat( | |
| [text_char, t2c_repeated], dim=-1) | |
| # Predykcja scorów podobieństwa przez MLP head | |
| text_character_affinities: torch.Tensor = self.text_character_matching_head( | |
| text_char_t2c) | |
| # Reshape z [num_pairs, 1] na macierz [num_texts, num_characters] | |
| text_character_affinities = rearrange( | |
| text_character_affinities, "(i j) 1 -> i j", i=text_i.shape[0]) | |
| if apply_sigmoid: | |
| text_character_affinities = text_character_affinities.sigmoid() | |
| affinity_matrices.append(text_character_affinities) | |
| return affinity_matrices | |
| def _get_text_tail_affinity_matrices( | |
| self, | |
| text_obj_tokens_for_this_batch: Optional[List[torch.FloatTensor]] = None, | |
| tail_obj_tokens_for_batch: Optional[List[torch.FloatTensor]] = None, | |
| apply_sigmoid: bool = True, | |
| ) -> List[torch.Tensor]: | |
| """ | |
| Oblicza macierze podobieństwa między tekstami a ogonami dymków. | |
| Dla każdej pary (tekst, ogon) oblicza prawdopodobieństwo, że dany tekst | |
| należy do danego ogona dymku. Używane do łączenia tekstów z dymkami dialogowymi. | |
| Args: | |
| text_obj_tokens_for_this_batch: Lista tokenów tekstów dla każdego obrazu | |
| tail_obj_tokens_for_batch: Lista tokenów ogonów dla każdego obrazu | |
| apply_sigmoid: Czy aplikować sigmoid do scorów (konwersja logitów na prawdop.) | |
| Returns: | |
| Lista macierzy podobieństwa, jedna dla każdego obrazu. | |
| Każda macierz ma kształt [num_texts, num_tails]. | |
| Wartość macierzy[i][j] = prawdopodobieństwo, że tekst i należy do ogona j. | |
| """ | |
| assert not self.config.disable_detections | |
| assert tail_obj_tokens_for_batch is not None and text_obj_tokens_for_this_batch is not None | |
| affinity_matrices: List[torch.Tensor] = [] | |
| for tail_obj_tokens, text_obj_tokens in zip(tail_obj_tokens_for_batch, text_obj_tokens_for_this_batch): | |
| # Jeśli brak tekstów lub ogonów, zwróć pustą macierz | |
| if tail_obj_tokens.shape[0] == 0 or text_obj_tokens.shape[0] == 0: | |
| affinity_matrices.append(torch.zeros( | |
| text_obj_tokens.shape[0], tail_obj_tokens.shape[0]).type_as(tail_obj_tokens)) | |
| continue | |
| # Tworzenie par (text_i, tail_j) dla wszystkich kombinacji | |
| # text_i: każdy tekst i powtórzony num_tails razy | |
| text_i: torch.Tensor = repeat(text_obj_tokens, "i d -> i repeat d", | |
| repeat=tail_obj_tokens.shape[0]) | |
| # tail_j: wszystkie ogony j powtórzone dla każdego tekstu i | |
| tail_j: torch.Tensor = repeat(tail_obj_tokens, "j d -> repeat j d", | |
| repeat=text_obj_tokens.shape[0]) | |
| # Konkatenacja par: [text_i, tail_j] -> [num_pairs, 2*hidden_dim] | |
| text_tail: torch.Tensor = rearrange( | |
| [text_i, tail_j], "two i j d -> (i j) (two d)") | |
| # Predykcja scorów podobieństwa przez MLP head (bez dodatkowego tokenu kontekstu) | |
| text_tail_affinities: torch.Tensor = self.text_tail_matching_head( | |
| text_tail) | |
| # Reshape z [num_pairs, 1] na macierz [num_texts, num_tails] | |
| text_tail_affinities = rearrange( | |
| text_tail_affinities, "(i j) 1 -> i j", i=text_i.shape[0]) | |
| if apply_sigmoid: | |
| text_tail_affinities = text_tail_affinities.sigmoid() | |
| affinity_matrices.append(text_tail_affinities) | |
| return affinity_matrices | |
| # ============================================================================ | |
| # FUNKCJE POMOCNICZE (skopiowane z transformers.models.detr) | |
| # ============================================================================ | |
| # Copied from transformers.models.detr.modeling_detr._upcast | |
| def _upcast(t: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Konwertuje tensor na typ o wyższej precyzji aby uniknąć overflow podczas mnożeń. | |
| Args: | |
| t: Tensor do konwersji | |
| Returns: | |
| Tensor skonwertowany na float32/float64 (dla float) lub int32/int64 (dla int) | |
| """ | |
| # Chroni przed overflow numerycznym przez upcasting do równoważnego typu wyższej precyzji | |
| if t.is_floating_point(): | |
| return t if t.dtype in (torch.float32, torch.float64) else t.float() | |
| else: | |
| return t if t.dtype in (torch.int32, torch.int64) else t.int() | |
| # Copied from transformers.models.detr.modeling_detr.box_area | |
| def box_area(boxes: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Oblicza pole powierzchni dla zestawu bounding boxów w formacie (x1, y1, x2, y2). | |
| Args: | |
| boxes: Tensor z bounding boxami o kształcie [num_boxes, 4]. | |
| Oczekiwany format: (x1, y1, x2, y2) gdzie 0 <= x1 < x2 i 0 <= y1 < y2. | |
| Returns: | |
| Tensor zawierający pole powierzchni dla każdego boxa [num_boxes] | |
| """ | |
| boxes = _upcast(boxes) | |
| return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) | |
| # Copied from transformers.models.detr.modeling_detr.box_iou | |
| def box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |
| """ | |
| Oblicza IoU (Intersection over Union) między dwoma zestawami bounding boxów. | |
| Args: | |
| boxes1: Pierwszy zestaw boxów [N, 4] w formacie (x1, y1, x2, y2) | |
| boxes2: Drugi zestaw boxów [M, 4] w formacie (x1, y1, x2, y2) | |
| Returns: | |
| Krotka (iou, union): | |
| - iou: Macierz IoU [N, M] gdzie iou[i][j] = IoU między boxes1[i] a boxes2[j] | |
| - union: Macierz pól unii [N, M] | |
| """ | |
| area1: torch.Tensor = box_area(boxes1) | |
| area2: torch.Tensor = box_area(boxes2) | |
| # Obliczenie współrzędnych przecięcia (intersection) | |
| left_top: torch.Tensor = torch.max( | |
| boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] | |
| right_bottom: torch.Tensor = torch.min( | |
| boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] | |
| # Szerokość i wysokość przecięcia (clamp min=0 dla braku przecięcia) | |
| width_height: torch.Tensor = ( | |
| right_bottom - left_top).clamp(min=0) # [N,M,2] | |
| inter: torch.Tensor = width_height[:, :, | |
| 0] * width_height[:, :, 1] # [N,M] | |
| # Union = pole1 + pole2 - przecięcie | |
| union: torch.Tensor = area1[:, None] + area2 - inter | |
| # IoU = przecięcie / unia | |
| iou: torch.Tensor = inter / union | |
| return iou, union | |
| # Copied from transformers.models.detr.modeling_detr.generalized_box_iou | |
| def generalized_box_iou(boxes1: torch.Tensor, boxes2: torch.Tensor) -> torch.Tensor: | |
| """ | |
| Oblicza Generalized IoU (GIoU) między dwoma zestawami bounding boxów. | |
| GIoU rozszerza klasyczne IoU przez uwzględnienie najmniejszego obejmującego | |
| prostokąta (smallest enclosing box). GIoU ∈ [-1, 1], gdzie wyższe wartości | |
| oznaczają lepsze dopasowanie. W przeciwieństwie do IoU, GIoU może być ujemne | |
| gdy boxy się nie przecinają. | |
| Więcej: https://giou.stanford.edu/ | |
| Args: | |
| boxes1: Pierwszy zestaw boxów [N, 4] w formacie corners (x0, y0, x1, y1) | |
| boxes2: Drugi zestaw boxów [M, 4] w formacie corners (x0, y0, x1, y1) | |
| Returns: | |
| Macierz GIoU [N, M] gdzie giou[i][j] = GIoU między boxes1[i] a boxes2[j] | |
| Raises: | |
| ValueError: Jeśli boxy nie są w poprawnym formacie (x0 < x1, y0 < y1) | |
| """ | |
| # Walidacja formatu boksów (zdegenerowane boksy dają inf/nan) | |
| if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): | |
| raise ValueError( | |
| f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") | |
| if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): | |
| raise ValueError( | |
| f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") | |
| # Obliczenie standardowego IoU i unii | |
| iou: torch.Tensor | |
| union: torch.Tensor | |
| iou, union = box_iou(boxes1, boxes2) | |
| # Obliczenie najmniejszego obejmującego prostokąta (enclosing box) | |
| top_left: torch.Tensor = torch.min(boxes1[:, None, :2], boxes2[:, :2]) | |
| bottom_right: torch.Tensor = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) | |
| # Pole najmniejszego obejmującego prostokąta | |
| width_height: torch.Tensor = ( | |
| bottom_right - top_left).clamp(min=0) # [N,M,2] | |
| area: torch.Tensor = width_height[:, :, 0] * width_height[:, :, 1] | |
| # GIoU = IoU - (pole_obejmujące - unia) / pole_obejmujące | |
| return iou - (area - union) / area | |
| # Copied from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrHungarianMatcher with DeformableDetr->ConditionalDetr | |
| class ConditionalDetrHungarianMatcher(nn.Module): | |
| """ | |
| Hungarian Matcher - przypisanie predykcji do targetów metodą węgierską. | |
| Klasa oblicza optymalne dopasowanie 1-do-1 między predykcjami modelu a ground truth | |
| targets używając algorytmu węgierskiego (Hungarian algorithm). Jest używana podczas | |
| treningu do określenia, która predykcja odpowiada któremu obiektowi ground truth. | |
| Ze względów wydajnościowych, targets nie zawierają klasy "no_object". W efekcie | |
| zazwyczaj jest więcej predykcji niż targetów. W takim przypadku wykonujemy dopasowanie | |
| 1-do-1 dla najlepszych predykcji, a pozostałe są niedopasowane (traktowane jako non-objects). | |
| Koszt dopasowania (matching cost) składa się z trzech komponentów: | |
| 1. class_cost: koszt błędu klasyfikacji (focal loss) | |
| 2. bbox_cost: koszt błędu L1 współrzędnych bounding boxa | |
| 3. giou_cost: koszt negatywnego GIoU między bounding boxami | |
| Attributes: | |
| class_cost: Względna waga błędu klasyfikacji w koszcie dopasowania | |
| bbox_cost: Względna waga błędu L1 współrzędnych bbox w koszcie dopasowania | |
| giou_cost: Względna waga GIoU loss bbox w koszcie dopasowania | |
| """ | |
| def __init__(self, class_cost: float = 1, bbox_cost: float = 1, giou_cost: float = 1) -> None: | |
| """ | |
| Inicjalizuje Hungarian Matcher z wagami kosztów. | |
| Args: | |
| class_cost: Waga kosztu klasyfikacji (domyślnie 1.0) | |
| bbox_cost: Waga kosztu L1 bbox (domyślnie 1.0) | |
| giou_cost: Waga kosztu GIoU (domyślnie 1.0) | |
| Raises: | |
| ValueError: Jeśli wszystkie koszty są zerowe (brak funkcji kosztu) | |
| """ | |
| super().__init__() | |
| self.class_cost: float = class_cost | |
| self.bbox_cost: float = bbox_cost | |
| self.giou_cost: float = giou_cost | |
| if class_cost == 0 and bbox_cost == 0 and giou_cost == 0: | |
| raise ValueError("All costs of the Matcher can't be 0") | |
| def forward(self, outputs: Dict[str, torch.Tensor], targets: List[Dict[str, torch.Tensor]]) -> List[Tuple[torch.Tensor, torch.Tensor]]: | |
| """ | |
| Wykonuje dopasowanie węgierskie między predykcjami a ground truth targets. | |
| Oblicza macierz kosztów dla wszystkich par (predykcja, target) składającą się z: | |
| 1. Focal loss dla klasyfikacji (alpha=0.25, gamma=2.0) | |
| 2. L1 distance między współrzędnymi bbox (format center) | |
| 3. Negatywny GIoU między bbox (format corners) | |
| Następnie używa algorytmu węgierskiego (linear_sum_assignment) do znalezienia | |
| optymalnego dopasowania 1-do-1 minimalizującego całkowity koszt dla każdego | |
| przykładu w batchu. | |
| Args: | |
| outputs: Słownik zawierający predykcje modelu: | |
| - "logits": torch.Tensor [batch_size, num_queries, num_classes] | |
| Logity klasyfikacji dla wszystkich queries | |
| - "pred_boxes": torch.Tensor [batch_size, num_queries, 4] | |
| Predykcje bounding boxów w formacie center (cx, cy, w, h) | |
| targets: Lista słowników (len=batch_size), każdy target zawiera: | |
| - "class_labels": torch.Tensor [num_target_boxes] | |
| Ground truth etykiety klas dla obiektów w obrazie | |
| - "boxes": torch.Tensor [num_target_boxes, 4] | |
| Ground truth bounding boxy w formacie center | |
| Returns: | |
| Lista krotek (len=batch_size), każda krotka to (index_i, index_j): | |
| - index_i: torch.Tensor [min(num_queries, num_target_boxes)] | |
| Indeksy wybranych predykcji (w kolejności) | |
| - index_j: torch.Tensor [min(num_queries, num_target_boxes)] | |
| Indeksy odpowiadających im targetów (w kolejności) | |
| Dla każdego elementu batcha: len(index_i) = len(index_j) = min(num_queries, num_target_boxes) | |
| Note: | |
| Metoda oznaczona @torch.no_grad() - nie obliczamy gradientów dla matchingu | |
| (dopasowanie służy tylko do określenia, które predykcje trenować względem | |
| których targetów, nie uczestniczy w backpropagation). | |
| """ | |
| batch_size, num_queries = outputs["logits"].shape[:2] | |
| # Spłaszczamy tensory aby obliczyć macierze kosztów w batch | |
| # Kształt: [batch_size * num_queries, num_classes] | |
| out_prob: torch.Tensor = outputs["logits"].flatten(0, 1).sigmoid() | |
| # Kształt: [batch_size * num_queries, 4] | |
| out_bbox: torch.Tensor = outputs["pred_boxes"].flatten(0, 1) | |
| # Konkatenujemy również etykiety i boxy targetów ze wszystkich przykładów w batchu | |
| target_ids: torch.Tensor = torch.cat( | |
| [v["class_labels"] for v in targets]) | |
| target_bbox: torch.Tensor = torch.cat([v["boxes"] for v in targets]) | |
| # Obliczamy koszt klasyfikacji używając focal loss | |
| # Focal loss daje większą wagę trudnym przykładom (alpha=0.25, gamma=2.0) | |
| alpha: float = 0.25 | |
| gamma: float = 2.0 | |
| # Koszt dla negatywnej klasy (predykcja tła gdy target to obiekt) | |
| neg_cost_class: torch.Tensor = (1 - alpha) * (out_prob**gamma) * \ | |
| (-(1 - out_prob + 1e-8).log()) | |
| # Koszt dla pozytywnej klasy (predykcja obiektu gdy target to obiekt) | |
| pos_cost_class: torch.Tensor = alpha * \ | |
| ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) | |
| # Finalna macierz kosztów klasyfikacji: różnica między kosztem pos i neg | |
| # Kształt: [batch_size * num_queries, num_total_targets] | |
| class_cost: torch.Tensor = pos_cost_class[:, target_ids] - \ | |
| neg_cost_class[:, target_ids] | |
| # Obliczamy koszt L1 między bounding boxami | |
| # cdist oblicza parwise distance z normą L1 (Manhattan distance) | |
| # Kształt: [batch_size * num_queries, num_total_targets] | |
| bbox_cost: torch.Tensor = torch.cdist(out_bbox, target_bbox, p=1) | |
| # Obliczamy koszt GIoU między bounding boxami | |
| # Najpierw konwertujemy z formatu center (cx, cy, w, h) do corners (x1, y1, x2, y2) | |
| # GIoU jest negowany bo chcemy minimalizować koszt (wyższy GIoU = lepsze dopasowanie) | |
| # Kształt: [batch_size * num_queries, num_total_targets] | |
| giou_cost: torch.Tensor = -generalized_box_iou(center_to_corners_format( | |
| out_bbox), center_to_corners_format(target_bbox)) | |
| # Finalna macierz kosztów - ważona suma trzech komponentów | |
| # Kształt: [batch_size * num_queries, num_total_targets] | |
| cost_matrix: torch.Tensor = self.bbox_cost * bbox_cost + \ | |
| self.class_cost * class_cost + self.giou_cost * giou_cost | |
| # Przekształcamy z powrotem do kształtu [batch_size, num_queries, num_total_targets] | |
| # i przenosimy do CPU (linear_sum_assignment wymaga CPU) | |
| cost_matrix = cost_matrix.view(batch_size, num_queries, -1).cpu() | |
| # Rozdzielamy macierz kosztów dla każdego przykładu w batchu | |
| # sizes zawiera liczbę targetów dla każdego przykładu | |
| sizes: List[int] = [len(v["boxes"]) for v in targets] | |
| # Dla każdego przykładu wykonujemy algorytm węgierski (linear_sum_assignment) | |
| # który znajduje optymalne dopasowanie minimalizujące całkowity koszt | |
| indices: List[Tuple[NDArray, NDArray]] = [linear_sum_assignment(c[i]) for i, c in enumerate( | |
| cost_matrix.split(sizes, -1))] | |
| # Konwertujemy numpy arrays na torch tensors i zwracamy listę krotek (pred_idx, target_idx) | |
| return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] | |