from typing import List, Optional, Tuple

import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from toolkit.models.loaders.comfy import get_comfy_path


def get_umt5_encoder(
    model_path: str,
    tokenizer_subfolder: Optional[str] = None,
    encoder_subfolder: Optional[str] = None,
    torch_dtype: torch.dtype = torch.bfloat16,
    comfy_files: List[str] = [
        "text_encoders/umt5_xxl_fp16.safetensors",
        "text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors",
    ],
) -> Tuple[AutoTokenizer, UMT5EncoderModel]:
| """ | |
| Load the UMT5 encoder model from the specified path. | |
| """ | |
| tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder=tokenizer_subfolder) | |
    comfy_path = get_comfy_path(comfy_files)
    # The override below disables the ComfyUI single-file path and forces the
    # from_pretrained fallback.
    comfy_path = None
    if comfy_path is not None:
        text_encoder = UMT5EncoderModel.from_single_file(
            comfy_path, torch_dtype=torch_dtype
        )
    else:
        print(f"Using {model_path} for UMT5 encoder.")
        text_encoder = UMT5EncoderModel.from_pretrained(
            model_path, subfolder=encoder_subfolder, torch_dtype=torch_dtype
        )
    return tokenizer, text_encoder
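

# A minimal usage sketch, not part of the loader itself. The repo id and
# subfolder names below are assumptions for illustration; substitute the
# checkpoint layout your project actually uses.
if __name__ == "__main__":
    tokenizer, text_encoder = get_umt5_encoder(
        model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # assumed Diffusers-style repo id
        tokenizer_subfolder="tokenizer",  # assumed subfolder name
        encoder_subfolder="text_encoder",  # assumed subfolder name
        torch_dtype=torch.bfloat16,
    )
    # Tokenize a prompt and run it through the encoder to get text embeddings.
    inputs = tokenizer("a photo of a cat", return_tensors="pt")
    with torch.no_grad():
        embeddings = text_encoder(**inputs).last_hidden_state
    print(embeddings.shape)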