import torch
from pathlib import Path
from loguru import logger
from hymm_sp.constants import PROMPT_TEMPLATE, PRECISION_TO_TYPE
from hymm_sp.vae import load_vae
from hymm_sp.modules import load_model
from hymm_sp.text_encoder import TextEncoder
import torch.distributed
from hymm_sp.modules.parallel_states import (
    nccl_info,
)
from hymm_sp.modules.fp8_optimization import convert_fp8_linear


class Inference(object):
    def __init__(self,
                 args,
                 vae,
                 vae_kwargs,
                 text_encoder,
                 model,
                 text_encoder_2=None,
                 pipeline=None,
                 cpu_offload=False,
                 device=None,
                 logger=None):
        self.vae = vae
        self.vae_kwargs = vae_kwargs
        self.text_encoder = text_encoder
        self.text_encoder_2 = text_encoder_2
        self.model = model
        self.pipeline = pipeline
        self.cpu_offload = cpu_offload
        self.args = args
        self.device = device if device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
        # Under sequence parallelism, pin this process to the GPU matching its rank.
        if nccl_info.sp_size > 1:
            self.device = torch.device(f"cuda:{torch.distributed.get_rank()}")
        self.logger = logger

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_path,
                        args,
                        device=None,
                        **kwargs):
        """
        Initialize the Inference pipeline.

        Args:
            pretrained_model_path (str or pathlib.Path): The model root path, including the text-to-video transformer, text encoder and VAE checkpoints.
            args (argparse.Namespace): Inference arguments (precision, offloading, model and encoder settings).
            device (str or torch.device, optional): The device for inference. Defaults to "cuda" if available, otherwise "cpu".
        """
        logger.info(f"Got text-to-video model root path: {pretrained_model_path}")

        # ==================== Set device and disable gradient ====================
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        torch.set_grad_enabled(False)

        # ======================== Build the transformer ==========================
        logger.info("Building model...")
        factor_kwargs = {'device': 'cpu' if args.cpu_offload else device, 'dtype': PRECISION_TO_TYPE[args.precision]}
        in_channels = args.latent_channels
        out_channels = args.latent_channels
        print("=" * 25, "build model", "=" * 25)
        model = load_model(
            args,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs
        )
        if args.use_fp8:
            convert_fp8_linear(model, pretrained_model_path, original_dtype=PRECISION_TO_TYPE[args.precision])
        if args.cpu_offload:
            # Keep the transformer on CPU to save GPU memory when offloading is enabled.
            print("=" * 20, "load transformer to cpu")
            model = model.to('cpu')
            torch.cuda.empty_cache()
        else:
            model = model.to(device)
        model = Inference.load_state_dict(args, model, pretrained_model_path)
        model.eval()

        # ============================= Build extra models =========================
        # VAE
        print("=" * 25, "load vae", "=" * 25)
        vae, _, s_ratio, t_ratio = load_vae(args.vae, args.vae_precision, logger=logger,
                                            device='cpu' if args.cpu_offload else device)
        vae_kwargs = {'s_ratio': s_ratio, 't_ratio': t_ratio}

        # Text encoder
        if args.prompt_template_video is not None:
            crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
        else:
            crop_start = 0
        max_length = args.text_len + crop_start
        prompt_template_video = PROMPT_TEMPLATE[args.prompt_template_video] if args.prompt_template_video is not None else None

        print("=" * 25, "load llava", "=" * 25)
        text_encoder = TextEncoder(text_encoder_type=args.text_encoder,
                                   max_length=max_length,
                                   text_encoder_precision=args.text_encoder_precision,
                                   tokenizer_type=args.tokenizer,
                                   use_attention_mask=args.use_attention_mask,
                                   prompt_template_video=prompt_template_video,
                                   hidden_state_skip_layer=args.hidden_state_skip_layer,
                                   apply_final_norm=args.apply_final_norm,
                                   reproduce=args.reproduce,
                                   logger=logger,
                                   device='cpu' if args.cpu_offload else device,
                                   )

        # Optional secondary text encoder, only built when configured.
        text_encoder_2 = None
        if args.text_encoder_2 is not None:
            text_encoder_2 = TextEncoder(text_encoder_type=args.text_encoder_2,
                                         max_length=args.text_len_2,
                                         text_encoder_precision=args.text_encoder_precision_2,
                                         tokenizer_type=args.tokenizer_2,
                                         use_attention_mask=args.use_attention_mask,
                                         reproduce=args.reproduce,
                                         logger=logger,
                                         device='cpu' if args.cpu_offload else device,
                                         )

        return cls(args=args,
                   vae=vae,
                   vae_kwargs=vae_kwargs,
                   text_encoder=text_encoder,
                   model=model,
                   text_encoder_2=text_encoder_2,
                   device=device,
                   logger=logger)
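
    # Checkpoint layout assumed by load_state_dict below: either a single .pt file,
    # or a directory containing a "*_model_states.pt" file (e.g. a DeepSpeed-style
    # "mp_rank_00_model_states.pt") whose weights sit under `args.load_key`
    # (commonly "module"); passing load_key="." means the file already stores a
    # flat state dict.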
    @staticmethod
    def load_state_dict(args, model, ckpt_path):
        load_key = args.load_key
        ckpt_path = Path(ckpt_path)
        if ckpt_path.is_dir():
            # Resolve a checkpoint directory to its "*_model_states.pt" file.
            ckpt_path = next(ckpt_path.glob("*_model_states.pt"))
        state_dict = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
        if load_key in state_dict:
            state_dict = state_dict[load_key]
        elif load_key == ".":
            # "." means the checkpoint already stores a flat state dict.
            pass
        else:
            raise KeyError(f"Key '{load_key}' not found in the checkpoint. "
                           f"Existing keys: {list(state_dict.keys())}")
        model.load_state_dict(state_dict, strict=False)
        return model
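
    # Path convention assumed by get_exp_dir_and_ckpt_id below:
    #   <exp_dir>/checkpoints/<ckpt_id>/<checkpoint file>
    # i.e. the experiment directory sits two levels above the checkpoint file
    # and the checkpoint id is the name of its immediate parent directory.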
    def get_exp_dir_and_ckpt_id(self):
        if self.ckpt is None:
            raise ValueError("The checkpoint path is not provided.")
        ckpt = Path(self.ckpt)
        if ckpt.parents[1].name == "checkpoints":
            # It should be a standard checkpoint path. We use the parent directory as the default save directory.
            exp_dir = ckpt.parents[2]
        else:
            raise ValueError(f"We cannot infer the experiment directory from the checkpoint path: {ckpt}. "
                             f"It seems that the checkpoint path is not standard. Please explicitly provide the "
                             f"save path by --save-path.")
        return exp_dir, ckpt.parent.name

    @staticmethod
    def parse_size(size):
        # Normalize a size argument to a [height, width] pair.
        if isinstance(size, int):
            size = [size]
        if not isinstance(size, (list, tuple)):
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        if len(size) == 1:
            size = [size[0], size[0]]
        if len(size) != 2:
            raise ValueError(f"Size must be an integer or (height, width), got {size}.")
        return size
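

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the upstream module).
# In practice `args` comes from the repository's argument parser; the Namespace
# below lists only the attributes read in this file, and the concrete values
# (precision strings, channel count, encoder/tokenizer/VAE names, paths) are
# placeholder assumptions that must match the shipped checkpoints. Any extra
# fields consumed inside load_model/load_vae/TextEncoder are omitted here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        precision="bf16", latent_channels=16,          # placeholder values
        cpu_offload=False, use_fp8=False,
        load_key="module",                             # key holding the weights in *_model_states.pt
        vae="884-16c-hy", vae_precision="fp16",        # placeholder VAE type/precision
        text_encoder="llm", text_encoder_precision="fp16", tokenizer="llm",
        text_len=256, prompt_template_video=None,      # no video prompt template in this sketch
        use_attention_mask=True, hidden_state_skip_layer=2,
        apply_final_norm=False, reproduce=False,
        text_encoder_2=None,                           # skip the secondary text encoder
    )
    pipeline = Inference.from_pretrained("ckpts/transformer", args)
    print(pipeline.device)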