from typing import Dict, Optional

from torch import nn


class TextEmbExtractor(nn.Module):
    """Wraps a tokenizer/text-encoder pair and returns text embeddings."""

    def __init__(self, tokenizer, text_encoder) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def forward(
        self,
        texts,
        text_params: Optional[Dict] = None,
    ):
        if text_params is None:
            text_params = {}

        # Tokenize with fixed-length padding so every batch has the same shape.
        special_prompt_input = self.tokenizer(
            texts,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Only pass an attention mask if the encoder is configured to use one.
        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = special_prompt_input.attention_mask.to(
                self.text_encoder.device
            )
        else:
            attention_mask = None

        embeddings = self.text_encoder(
            special_prompt_input.input_ids.to(self.text_encoder.device),
            attention_mask=attention_mask,
            **text_params,
        )
        return embeddings
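

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): the checkpoint name below is
# an assumption for illustration; any HuggingFace tokenizer/text-encoder pair
# with the same call signature works identically.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

    extractor = TextEmbExtractor(tokenizer, text_encoder)
    with torch.no_grad():
        outputs = extractor(["a photo of a cat", "a photo of a dog"])

    # CLIPTextModel returns a BaseModelOutputWithPooling; last_hidden_state
    # holds the per-token embeddings, shaped (batch, seq_len=77, hidden=512)
    # for this checkpoint.
    print(outputs.last_hidden_state.shape)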