| """ | |
| Copyright (c) 2022-present NAVER Corp. | |
| Copyright (c) 2025 Bytedance Ltd. and/or its affiliates. | |
| MIT License | |
| This file has been modified by [ByteDance Ltd. and/or its affiliates] on 20250118. | |
| The original file available at https://github.com/clovaai/donut/blob/master/donut/model.py was released under the MIT license. | |
| This modified file is released under the same license. | |
| """ | |

import logging
from collections import defaultdict
from typing import List, Optional

import torch
import torch.nn.functional as F
from PIL import Image
from timm.models.swin_transformer import SwinTransformer
from torch import nn
from transformers import (
    MBartConfig,
    MBartForCausalLM,
    StoppingCriteria,
    StoppingCriteriaList,
)
from transformers.file_utils import ModelOutput
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel


class SwinEncoder(nn.Module):
    r"""
    Encoder based on SwinTransformer.

    Sets the initial weights and configuration from a pretrained SwinTransformer and then
    modifies the detailed configuration.

    Args:
        input_size: Input image size (width, height)
        align_long_axis: Whether to rotate the image if its height is greater than its width
        window_size: Window size of the SwinTransformer
        encoder_layer: Number of layers per stage of the SwinTransformer encoder
        patch_size: Patch size used by the patch embedding
        embed_dim: Embedding dimension of the first stage
        num_heads: Number of attention heads per stage
    """

    def __init__(
        self,
        input_size,
        align_long_axis: bool = False,
        window_size: int = 7,
        encoder_layer: List[int] = [2, 2, 14, 2],
        patch_size: List[int] = [4, 4],
        embed_dim: int = 128,
        num_heads: List[int] = [4, 8, 16, 32],
    ):
        super().__init__()
        if isinstance(input_size, int):
            input_size = [input_size, input_size]
        self.input_size = input_size
        self.align_long_axis = align_long_axis
        self.window_size = window_size
        self.encoder_layer = encoder_layer
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.model = SwinTransformer(
            img_size=self.input_size,
            depths=self.encoder_layer,
            window_size=self.window_size,
            patch_size=self.patch_size,
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_classes=0,
        )

    def forward(self, x: torch.Tensor, text_embedding: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Args:
            x: (batch_size, num_channels, height, width)
        """
        x = self.model.patch_embed(x)
        x = self.model.pos_drop(x)
        x = self.model.layers(x)
        return x
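
    # Shape sketch (assuming the default configuration above and a timm version whose
    # `layers` return flattened tokens): the 4-stage Swin downsamples each spatial
    # dimension by patch_size * 2**3 = 32 and widens channels to 8 * embed_dim = 1024,
    # so an input of (B, 3, 896, 672) yields about (896 // 32) * (672 // 32) = 588
    # tokens of dimension 1024.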


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def _set_dtype(self, dtype):
        self._dtype = dtype

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(dtype=self._dtype))
        return ret.type(orig_type)
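
    # Usage note (a sketch of the intended call order): `_set_dtype` must be called once
    # before the first forward pass, e.g. `ln._set_dtype(torch.float32)` when the rest of
    # the model runs in fp16; otherwise `forward` fails because `self._dtype` is unset.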


class BARTDecoder(nn.Module):
    """
    Decoder based on multilingual BART.

    Sets the initial weights and configuration from a pretrained multilingual BART model,
    and modifies the detailed configuration to act as a Donut decoder.

    Args:
        tokenizer: Tokenizer whose vocabulary determines the decoder's embedding size
        decoder_layer: Number of layers of the BART decoder
        max_position_embeddings: The maximum sequence length to be trained
        hidden_dimension: Hidden size (`d_model`) of the decoder
    """

    def __init__(
        self,
        tokenizer,
        decoder_layer: int,
        max_position_embeddings: int,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_position_embeddings
        self.hidden_dimension = hidden_dimension
        self.tokenizer = tokenizer

        self.model = MBartForCausalLM(
            config=MBartConfig(
                tie_word_embeddings=True,
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
                scale_embedding=True,
                add_final_layer_norm=True,
                d_model=self.hidden_dimension,
            )
        )
        # self.model.config.is_encoder_decoder = True  # to get cross-attention
        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

    def add_special_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to the tokenizer and resize the token embeddings.
        """
        newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def add_tokens(self, list_of_tokens: List[str]):
        """
        Add regular tokens to the tokenizer and resize the token embeddings.
        """
        newly_added_num = self.tokenizer.add_tokens(sorted(set(list_of_tokens)))
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))
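
    # Usage sketch (hypothetical token strings, shown only to illustrate the two methods):
    #   decoder.add_special_tokens(["<s_answer>", "</s_answer>"])  # task/field markers
    #   decoder.add_tokens(["<tab>"])                               # plain vocabulary entries
    # Both calls resize the embedding matrix only when the tokenizer actually gains new ids.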

    def prepare_inputs_for_inference(
        self,
        input_ids: torch.Tensor,
        encoder_outputs: torch.Tensor,
        past=None,
        past_key_values=None,
        use_cache: Optional[bool] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Args:
            input_ids: (batch_size, sequence_length)

        Returns:
            input_ids: (batch_size, sequence_length)
            attention_mask: (batch_size, sequence_length)
            encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
        """
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
        past = past or past_key_values
        if past is not None:
            input_ids = input_ids[:, -1:]
        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": use_cache,
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
        }
        return output
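
    # Cached-decoding note (a sketch of how `generate` uses this hook): on the first step
    # the full prompt ids pass through unchanged; on later steps `past`/`past_key_values`
    # is populated, so only the most recent token is fed while the attention mask still
    # covers the whole prefix.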

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )


def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
    """
    Resize position embeddings.

    Truncate if the sequence length of the MBart backbone is greater than the given max_length,
    else interpolate to max_length.
    """
    if weight.shape[0] > max_length:
        weight = weight[:max_length, ...]
    else:
        weight = (
            F.interpolate(
                weight.permute(1, 0).unsqueeze(0),
                size=max_length,
                mode="linear",
                align_corners=False,
            )
            .squeeze(0)
            .permute(1, 0)
        )
    return weight
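
# Shape sketch: `weight` is an (old_max_length, hidden_dim) positional-embedding matrix.
# Interpolating, say, a (1024, 1024) table up to max_length=4096 treats each embedding
# dimension as a 1-D signal of length 1024 and linearly resamples it to length 4096,
# returning a (4096, 1024) tensor; truncation simply keeps the first max_length rows.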


class DonutConfig(PretrainedConfig):
    def __init__(
        self,
        decoder_layer: int = 10,
        max_position_embeddings: Optional[int] = None,
        max_length: int = 4096,
        hidden_dimension: int = 1024,
        **kwargs,
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_length if max_position_embeddings is None else max_position_embeddings
        self.max_length = max_length
        self.hidden_dimension = hidden_dimension
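
# Construction sketch (illustrative values matching the defaults above):
#   config = DonutConfig(decoder_layer=10, max_length=4096, hidden_dimension=1024)
# When max_position_embeddings is omitted it falls back to max_length, so the decoder's
# positional table covers the full generation window.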


class RunningVarTorch:
    def __init__(self, L=15, norm=False):
        self.values = None
        self.L = L
        self.norm = norm

    def push(self, x: torch.Tensor):
        assert x.dim() == 1
        if self.values is None:
            self.values = x[:, None]
        elif self.values.shape[1] < self.L:
            self.values = torch.cat((self.values, x[:, None]), 1)
        else:
            self.values = torch.cat((self.values[:, 1:], x[:, None]), 1)

    def variance(self):
        if self.values is None:
            return
        if self.norm:
            return torch.var(self.values, 1) / self.values.shape[1]
        else:
            return torch.var(self.values, 1)
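
# Behaviour sketch: `push` keeps a sliding window of the last L vectors (one column per
# step) and `variance` returns the per-element variance over that window, optionally
# normalised by the window length; for a batch of size B the result is a length-B tensor,
# one running variance per sequence.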


class StoppingCriteriaScores(StoppingCriteria):
    def __init__(self, threshold: float = 0.015, window_size: int = 200):
        super().__init__()
        self.threshold = threshold
        self.vars = RunningVarTorch(norm=True)
        self.varvars = RunningVarTorch(L=window_size)
        self.stop_inds = defaultdict(int)
        self.stopped = defaultdict(bool)
        self.size = 0
        self.window_size = window_size

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_scores = scores[-1]
        self.vars.push(last_scores.max(1)[0].float().cpu())
        self.varvars.push(self.vars.variance())
        self.size += 1
        if self.size < self.window_size:
            return False

        varvar = self.varvars.variance()
        for b in range(len(last_scores)):
            if varvar[b] < self.threshold:
                if self.stop_inds[b] > 0 and not self.stopped[b]:
                    self.stopped[b] = self.stop_inds[b] >= self.size
                else:
                    self.stop_inds[b] = int(min(max(self.size, 1) * 1.15 + 150 + self.window_size, 4095))
            else:
                self.stop_inds[b] = 0
                self.stopped[b] = False
        return all(self.stopped.values()) and len(self.stopped) > 0
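
# Heuristic sketch (a Nougat-style degeneration detector): when the variance of the
# per-step top-token confidences stays nearly flat over `window_size` steps, the model is
# likely looping, so a stop index slightly beyond the current step is scheduled; generation
# halts only once every sequence in the batch has passed its scheduled stop index.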


def batch(l, b=15):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[i : i + b])
    return subs


def subdiv(l, b=10):
    subs = []
    for i in range(len(l) - b):
        subs.append(l[: i + b])
    return subs
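
# Example (illustrative): for l = [0, 1, 2, 3, 4] and b = 2,
#   batch(l, 2)  -> [[0, 1], [1, 2], [2, 3]]           (sliding windows of length b)
#   subdiv(l, 2) -> [[0, 1], [0, 1, 2], [0, 1, 2, 3]]  (growing prefixes of length >= b)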


class DonutModel(PreTrainedModel):
    config_class = DonutConfig
    base_model_prefix = "donut"

    def __init__(self, config: DonutConfig, vision_tower=None, tokenizer=None):
        super().__init__(config)
        self.config = config
        self.tokenizer = tokenizer
        self.vpm = vision_tower

        # build language model
        self.llm = BARTDecoder(
            tokenizer=tokenizer,
            decoder_layer=self.config.decoder_layer,
            max_position_embeddings=self.config.max_position_embeddings,
            hidden_dimension=self.config.hidden_dimension,
        )
        self.ids_to_tokens = {id: content for content, id in self.llm.tokenizer.vocab.items()}

    def get_input_embeddings(self, tensor):
        return self.llm.model.get_input_embeddings()(tensor)

    def forward(
        self,
        inputs: dict,
    ):
        image_tensors = inputs["pixel_values"]
        input_ids = inputs["input_ids"].contiguous()
        attention_mask = inputs["attention_mask"]
        labels = inputs["labels"].contiguous()

        encoder_outputs = self.vpm(
            image_tensors,
            text_embedding=self.llm.model.get_input_embeddings()(input_ids),
        )
        decoder_outputs = self.llm(
            input_ids=input_ids,
            encoder_hidden_states=encoder_outputs,
            attention_mask=attention_mask,
            labels=labels,
        )
        return decoder_outputs
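
    # Training-path sketch: the vision-tower features are passed as `encoder_hidden_states`,
    # so the mBART decoder cross-attends to image tokens while predicting `labels`; the loss
    # is computed inside MBartForCausalLM from the supplied labels.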

    def get_hidden_states_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)
        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        all_hidden_states = self.vpm.forward_features(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return all_hidden_states

    def get_attn_weights_during_inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
    ):
        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)
        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)
        last_attn_score = self.vpm.get_last_layer_cross_attn_score(
            image_tensors, text_embedding=self.get_input_embeddings(prompt_ids)
        )
        return last_attn_score

    def inference(
        self,
        prompt_ids: torch.Tensor,
        image: Image.Image = None,
        image_tensors: Optional[torch.Tensor] = None,
        return_attentions: bool = False,
        early_stopping: bool = True,
    ):
        """
        Generate a token sequence in an auto-regressive manner.

        Args:
            image: input document image (PIL.Image)
            image_tensors: (1, num_channels, height, width)
                prepared from `image` if not given
        """
        output = {
            "predictions": list(),
            "sequences": list(),
            "repeats": list(),
            "repetitions": list(),
        }
        if image is None and image_tensors is None:
            logging.warning("Image not found")
            return output

        if image_tensors is None:
            image_tensors = self.vpm.prepare_input(image).unsqueeze(0)
        if self.device.type != "mps":
            image_tensors = image_tensors.to(next(self.parameters()).dtype)
        image_tensors = image_tensors.to(self.device)
        prompt_ids = prompt_ids.to(self.device)

        last_hidden_state = self.vpm(image_tensors, text_embedding=self.get_input_embeddings(prompt_ids))
        encoder_outputs = ModelOutput(last_hidden_state=last_hidden_state, attentions=None)
        if len(encoder_outputs.last_hidden_state.size()) == 1:
            encoder_outputs.last_hidden_state = encoder_outputs.last_hidden_state.unsqueeze(0)

        # get decoder output
        decoder_output = self.llm.model.generate(
            input_ids=prompt_ids,
            encoder_outputs=encoder_outputs,
            min_length=1,
            max_length=self.config.max_length,
            pad_token_id=self.llm.tokenizer.pad_token_id,
            eos_token_id=self.llm.tokenizer.eos_token_id,
            use_cache=True,
            return_dict_in_generate=True,
            output_scores=True,
            output_attentions=return_attentions,
            do_sample=False,
            num_beams=1,
            stopping_criteria=StoppingCriteriaList([StoppingCriteriaScores()] if early_stopping else []),
        )

        output["repetitions"] = decoder_output.sequences.clone()
        output["sequences"] = decoder_output.sequences.clone()
        output["scores"] = torch.stack(decoder_output.scores, 1).softmax(-1).cpu().max(-1)[0]
        output["repetitions"] = self.llm.tokenizer.batch_decode(output["repetitions"], skip_special_tokens=False)
        return output