# Copyright 2025 aquif AI, Wan AI, Suno and HuggingFace teams

import json
import warnings
from dataclasses import asdict
from pathlib import Path
from typing import Dict, Optional, Union

import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizerBase,
)


class AquifDreamModelLoader:
    """Loads the aquif-Dream config, tokenizer and component backbone models."""

    def __init__(
        self,
        model_name_or_path: Union[str, Path] = "aquif-ai/aquif-Dream-6B-Exp",
        cache_dir: Optional[str] = None,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
    ):
        self.model_name_or_path = str(model_name_or_path)
        self.cache_dir = cache_dir
        self.device = device
        self.load_in_8bit = load_in_8bit
        self.load_in_4bit = load_in_4bit
        self.config = None
        self.model = None
        self.tokenizer = None
        self.component_models = {}

    def load_config(self) -> Optional[Dict]:
        try:
            self.config = AutoConfig.from_pretrained(
                self.model_name_or_path,
                cache_dir=self.cache_dir
            )
            return self.config
        except Exception as e:
            warnings.warn(f"Could not load config from {self.model_name_or_path}: {e}")
            return None

    def load_component_model(
        self,
        component_name: str,
        repo_id: str,
        load_weights: bool = False
    ) -> Optional[PreTrainedModel]:
        try:
            if component_name in self.component_models:
                return self.component_models[component_name]
            if load_weights:
                model = AutoModel.from_pretrained(
                    repo_id,
                    cache_dir=self.cache_dir,
                    device_map=self.device,
                    load_in_8bit=self.load_in_8bit,
                    load_in_4bit=self.load_in_4bit
                )
            else:
                # Instantiate the architecture from its config only (no weight download).
                model_config = AutoConfig.from_pretrained(repo_id, cache_dir=self.cache_dir)
                model = AutoModel.from_config(model_config)
            self.component_models[component_name] = model
            return model
        except Exception as e:
            warnings.warn(f"Could not load component {component_name}: {e}")
            return None

    def load_tokenizer(self, repo_id: str = "Wan-AI/Wan2.2-TI2V-5B") -> Optional[PreTrainedTokenizerBase]:
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                repo_id,
                cache_dir=self.cache_dir
            )
            return self.tokenizer
        except Exception as e:
            warnings.warn(f"Could not load tokenizer: {e}")
            return None

    def load_full_model(self, load_weights: bool = False):
        self.load_config()
        components_to_load = [
            ("video_backbone", "Wan-AI/Wan2.2-TI2V-5B"),
            ("vlm_backbone", "HuggingFaceTB/SmolVLM2-500M-Video-Instruct"),
            ("audio_backbone", "suno/bark-small")
        ]
        for component_name, repo_id in components_to_load:
            self.load_component_model(component_name, repo_id, load_weights=load_weights)
        self.load_tokenizer()
        if load_weights:
            try:
                from setup_aquif_dream import AquifDreamForVideoWithAudio
                self.model = AquifDreamForVideoWithAudio.from_pretrained(
                    self.model_name_or_path,
                    cache_dir=self.cache_dir
                )
                if not self.load_in_8bit and not self.load_in_4bit:
                    self.model = self.model.to(self.device)
            except Exception as e:
                warnings.warn(f"Could not load full model: {e}")
        return self.model

    def get_component_state_dict(self, component_name: str) -> Optional[Dict]:
        if component_name not in self.component_models:
            return None
        return self.component_models[component_name].state_dict()

    def get_config_dict(self) -> Dict:
        if self.config is None:
            self.load_config()
        if self.config is None:
            # Config could not be loaded; return an empty dict rather than failing.
            return {}
        return asdict(self.config) if hasattr(self.config, '__dataclass_fields__') else vars(self.config)
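
# --- Illustrative sketch (not part of the public API) ---
# Shows one way the loader might be used to inspect the component architectures
# without downloading weights; assumes the Hugging Face repos referenced above
# are reachable from this environment.
def _example_inspect_components() -> None:
    loader = AquifDreamModelLoader()
    loader.load_full_model(load_weights=False)  # configs/architectures only
    for name, model in loader.component_models.items():
        n_params = sum(p.numel() for p in model.parameters())
        print(f"{name}: {n_params / 1e6:.1f}M parameters")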

class AquifDreamPipeline:
    """High-level pipeline wrapping the aquif-Dream model for generation tasks."""

    def __init__(
        self,
        model_name_or_path: Union[str, Path] = "aquif-ai/aquif-Dream-6B-Exp",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        dtype: torch.dtype = torch.float32,
        load_components: bool = True
    ):
        self.model_name_or_path = model_name_or_path
        self.device = device
        self.dtype = dtype
        self.loader = AquifDreamModelLoader(
            model_name_or_path=model_name_or_path,
            device=device
        )
        if load_components:
            self.loader.load_full_model(load_weights=False)
        self.model = self.loader.model
        self.tokenizer = self.loader.tokenizer
        self.config = self.loader.config

    def prepare_text_input(self, text: str, max_length: int = 77) -> Dict[str, torch.Tensor]:
        if self.tokenizer is None:
            self.tokenizer = self.loader.load_tokenizer()
        encoded = self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        return {
            "input_ids": encoded["input_ids"].to(self.device),
            "attention_mask": encoded["attention_mask"].to(self.device)
        }

    def generate_video(
        self,
        prompt: str,
        num_frames: int = 240,
        height: int = 512,
        width: int = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        return_intermediate: bool = False
    ) -> Dict[str, torch.Tensor]:
        text_inputs = self.prepare_text_input(prompt)
        outputs = self.model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"]
        )
        video_latents = outputs["generated_video_latents"]
        return {
            "video_latents": video_latents,
            "video_shape": (num_frames, height, width, 3),
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale
        }

    def generate_audio(
        self,
        prompt: Optional[str] = None,
        text_prompt: Optional[str] = None,
        sample_rate: int = 24000,
        duration: float = 10.0,
        temperature: float = 0.7
    ) -> Dict[str, torch.Tensor]:
        audio_prompt = prompt or text_prompt
        if audio_prompt is None:
            raise ValueError("Either prompt or text_prompt must be provided")
        text_inputs = self.prepare_text_input(audio_prompt)
        outputs = self.model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"]
        )
        audio_waveform = outputs["generated_audio_waveform"]
        return {
            "audio_waveform": audio_waveform,
            "sample_rate": sample_rate,
            "duration": duration,
            "prompt": audio_prompt,
            "temperature": temperature
        }

    def generate_video_with_audio(
        self,
        prompt: str,
        num_frames: int = 240,
        height: int = 512,
        width: int = 512,
        sample_rate: int = 24000,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        temperature: float = 0.7
    ) -> Dict[str, Union[torch.Tensor, str, float]]:
        text_inputs = self.prepare_text_input(prompt)
        outputs = self.model(
            input_ids=text_inputs["input_ids"],
            attention_mask=text_inputs["attention_mask"]
        )
        video_latents = outputs["generated_video_latents"]
        audio_waveform = outputs["generated_audio_waveform"]
        captions = outputs["generated_captions"]
        sync_map = outputs["sync_information"]
        sync_confidence = outputs["sync_confidence"]
        return {
            "video": {
                "latents": video_latents,
                "shape": (num_frames, height, width, 3),
                "fps": 24,
                "duration": num_frames / 24.0
            },
            "audio": {
                "waveform": audio_waveform,
                "sample_rate": sample_rate,
                # Duration is derived from the sample dimension of the waveform.
                "duration": audio_waveform.shape[-1] / sample_rate if isinstance(audio_waveform, torch.Tensor) else 0
            },
            "captions": captions,
            "synchronization": {
                "sync_map": sync_map,
                "sync_confidence": sync_confidence,
                "alignment_method": "cross_attention_temporal"
            },
            "metadata": {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "temperature": temperature
            }
        }

    def decode_video_latents(self, latents: torch.Tensor) -> torch.Tensor:
        if self.model is None or not hasattr(self.model, 'wan_vae'):
            warnings.warn("VAE decoder not available in current model setup")
            return latents
        with torch.no_grad():
            video_frames = self.model.wan_vae.decode(latents)
        return video_frames
    def encode_video_frames(self, frames: torch.Tensor) -> torch.Tensor:
        if self.model is None or not hasattr(self.model, 'wan_vae'):
            warnings.warn("VAE encoder not available in current model setup")
            return frames
        with torch.no_grad():
            latents = self.model.wan_vae.encode(frames)
        return latents

    def extract_visual_features(self, video_frames: torch.Tensor) -> torch.Tensor:
        if self.model is None or not hasattr(self.model, 'vlm_vision_encoder'):
            warnings.warn("Vision encoder not available in current model setup")
            return None
        with torch.no_grad():
            visual_features = self.model.vlm_vision_encoder(video_frames)
        return visual_features

    def generate_captions(self, video_frames: torch.Tensor) -> torch.Tensor:
        if self.model is None or not hasattr(self.model, 'vlm_decoder'):
            warnings.warn("Vision-language decoder not available")
            return None
        visual_features = self.extract_visual_features(video_frames)
        with torch.no_grad():
            caption_logits = self.model.vlm_decoder(
                input_ids=torch.arange(64).unsqueeze(0).to(self.device),
                vision_embeddings=visual_features
            )
        return caption_logits

    def get_model_info(self) -> Dict:
        if self.config is None:
            self.loader.load_config()
        config_dict = self.loader.get_config_dict()
        return {
            "model_name": "aquif-Dream-6B-Exp",
            "model_type": config_dict.get("model_type", "aquif_dream"),
            "total_parameters": 5_920_000_000,
            "video_specs": {
                "resolution": config_dict.get("video_resolution", (512, 512)),
                "fps": config_dict.get("video_fps", 24),
                "duration": config_dict.get("video_duration", 10.0),
                "total_frames": config_dict.get("video_total_frames", 240)
            },
            "audio_specs": {
                "sample_rate": config_dict.get("audio_sample_rate", 24000),
                "channels": config_dict.get("audio_channels", 1),
                "max_duration": config_dict.get("audio_max_duration", 10.0)
            },
            "components": {
                "video_backbone": {"name": "Wan2.2-TI2V-5B", "parameters": 5_000_000_000},
                "vlm_backbone": {"name": "SmolVLM2-500M-Video-Instruct", "parameters": 500_000_000},
                "audio_backbone": {"name": "Suno Bark Small", "parameters": 420_000_000}
            },
            "device": str(self.device),
            "dtype": str(self.dtype),
            "unified_embedding_dim": config_dict.get("unified_embedding_dim", 768)
        }

    def to(self, device: str):
        self.device = device
        if self.model is not None:
            self.model = self.model.to(device)
        return self

    def to_dtype(self, dtype: torch.dtype):
        self.dtype = dtype
        if self.model is not None:
            self.model = self.model.to(dtype)
        return self

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Union[str, Path] = "aquif-ai/aquif-Dream-6B-Exp",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        **kwargs
    ):
        return cls(model_name_or_path=model_name_or_path, device=device, **kwargs)
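
# --- Minimal usage sketch (illustrative, not an official entry point) ---
# Assumes the "aquif-ai/aquif-Dream-6B-Exp" repo and the component repos above are
# reachable; only configs and architectures are instantiated here, no full weights.
if __name__ == "__main__":
    pipeline = AquifDreamPipeline.from_pretrained("aquif-ai/aquif-Dream-6B-Exp")
    print(json.dumps(pipeline.get_model_info(), indent=2, default=str))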