# Created by Fabio Sarracino

import logging
from typing import List

import numpy as np
import torch

from .base_vibevoice import BaseVibeVoiceNode

# Setup logging
logger = logging.getLogger("VibeVoice")


class VibeVoiceSingleSpeakerNode(BaseVibeVoiceNode):
    def __init__(self):
        super().__init__()
        # Register this instance for memory management
        try:
            from .free_memory_node import VibeVoiceFreeMemoryNode
            VibeVoiceFreeMemoryNode.register_single_speaker(self)
        except Exception:
            # Registration is optional; continue if the free-memory node is unavailable
            pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "text": ("STRING", {
                    "multiline": True,
                    "default": "Hello, this is a test of the VibeVoice text-to-speech system.",
                    "tooltip": "Text to convert to speech. Gets disabled when connected to another node.",
                    "forceInput": False,
                    "dynamicPrompts": True
                }),
                "model": (["VibeVoice-1.5B", "VibeVoice-Large", "VibeVoice-Large-Quant-4Bit"], {
                    "default": "VibeVoice-1.5B",
                    "tooltip": "Model to use. 1.5B is faster, Large has better quality, Quant-4Bit uses less VRAM (CUDA only)"
                }),
                "attention_type": (["auto", "eager", "sdpa", "flash_attention_2", "sage"], {
                    "default": "auto",
                    "tooltip": "Attention implementation. Auto selects the best available, eager is standard, sdpa is optimized PyTorch, flash_attention_2 requires compatible GPU, sage uses quantized attention for speedup (CUDA only)"
                }),
                "free_memory_after_generate": ("BOOLEAN", {"default": True, "tooltip": "Free model from memory after generation to save VRAM/RAM. Disable to keep model loaded for faster subsequent generations"}),
                "diffusion_steps": ("INT", {"default": 20, "min": 5, "max": 100, "step": 1, "tooltip": "Number of denoising steps. More steps = better quality but slower. Default: 20"}),
                "seed": ("INT", {"default": 42, "min": 0, "max": 2**32-1, "tooltip": "Random seed for generation. Default 42 is used in official examples"}),
                "cfg_scale": ("FLOAT", {"default": 1.3, "min": 0.5, "max": 3.5, "step": 0.05, "tooltip": "Classifier-free guidance scale (official default: 1.3)"}),
                "use_sampling": ("BOOLEAN", {"default": False, "tooltip": "Enable sampling mode. When False (default), uses deterministic generation like official examples"}),
            },
            "optional": {
                "voice_to_clone": ("AUDIO", {"tooltip": "Optional: Reference voice to clone. If not provided, synthetic voice will be used."}),
                "temperature": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
                "top_p": ("FLOAT", {"default": 0.95, "min": 0.1, "max": 1.0, "step": 0.05, "tooltip": "Only used when sampling is enabled"}),
                "max_words_per_chunk": ("INT", {"default": 250, "min": 100, "max": 500, "step": 50, "tooltip": "Maximum words per chunk for long texts. Lower values prevent speed issues but create more chunks."}),
            }
        }
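
    # ComfyUI turns the dict above into widgets/sockets and calls the FUNCTION named
    # below with the resolved values as keyword arguments, roughly:
    #   node = VibeVoiceSingleSpeakerNode()
    #   (audio,) = node.generate_speech(text="Hello world", model="VibeVoice-1.5B",
    #                                   diffusion_steps=20, seed=42, cfg_scale=1.3)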

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "generate_speech"
    CATEGORY = "VibeVoiceWrapper"
    DESCRIPTION = "Generate speech from text using Microsoft VibeVoice with optional voice cloning"
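
    # The AUDIO output follows ComfyUI's audio convention: a dict holding a float waveform
    # tensor plus an integer "sample_rate", e.g. {"waveform": tensor, "sample_rate": 24000}.
    # (The [batch, channels, samples] tensor layout is assumed from ComfyUI's standard
    # AUDIO type; this file itself only relies on the two dict keys.)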

    def _prepare_voice_samples(self, speakers: list, voice_to_clone) -> List[np.ndarray]:
        """Prepare voice samples from input audio or create synthetic ones"""
        if voice_to_clone is not None:
            # Use the base class method to prepare audio
            audio_np = self._prepare_audio_from_comfyui(voice_to_clone)
            if audio_np is not None:
                return [audio_np]

        # Create synthetic voice samples for speakers
        voice_samples = []
        for i, _speaker in enumerate(speakers):
            voice_sample = self._create_synthetic_voice_sample(i)
            voice_samples.append(voice_sample)
        return voice_samples
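
    # Note: _prepare_audio_from_comfyui and _create_synthetic_voice_sample are inherited
    # from BaseVibeVoiceNode. Each voice sample is assumed to be a float numpy array at
    # the 24 kHz rate used below; if the cloning audio cannot be converted, the method
    # falls back to one synthetic sample per speaker.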

    def generate_speech(self, text: str = "", model: str = "VibeVoice-1.5B",
                        attention_type: str = "auto", free_memory_after_generate: bool = True,
                        diffusion_steps: int = 20, seed: int = 42, cfg_scale: float = 1.3,
                        use_sampling: bool = False, voice_to_clone=None,
                        temperature: float = 0.95, top_p: float = 0.95,
                        max_words_per_chunk: int = 250):
        """Generate speech from text using VibeVoice"""
        try:
            # Use text directly (it now serves as both manual input and connection input)
            if text and text.strip():
                final_text = text
            else:
                raise Exception("No text provided. Please enter text or connect from LoadTextFromFile node.")

            # Get model mapping and load model with attention type
            model_mapping = self._get_model_mapping()
            model_path = model_mapping.get(model, model)
            self.load_model(model, model_path, attention_type)

            # For single speaker, we just use ["Speaker 1"]
            speakers = ["Speaker 1"]

            # Parse pause keywords from text
            segments = self._parse_pause_keywords(final_text)

            # Process segments
            all_audio_segments = []
            voice_samples = None  # Will be created on first text segment
            sample_rate = 24000  # VibeVoice uses 24kHz
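
            # _parse_pause_keywords (from the base class) is expected to yield
            # (type, content) tuples: ('pause', duration_in_ms) or ('text', text_chunk).
            # That is the contract the loop below relies on.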
            for seg_idx, (seg_type, seg_content) in enumerate(segments):
                if seg_type == 'pause':
                    # Generate silence for pause
                    duration_ms = seg_content
                    logger.info(f"Adding {duration_ms}ms pause")
                    silence_audio = self._generate_silence(duration_ms, sample_rate)
                    all_audio_segments.append(silence_audio)
                elif seg_type == 'text':
                    # Process text segment (with chunking if needed)
                    word_count = len(seg_content.split())
                    if word_count > max_words_per_chunk:
                        # Split long text into chunks
                        logger.info(f"Text segment {seg_idx+1} has {word_count} words, splitting into chunks...")
                        text_chunks = self._split_text_into_chunks(seg_content, max_words_per_chunk)
                        for chunk_idx, chunk in enumerate(text_chunks):
                            logger.info(f"Processing chunk {chunk_idx+1}/{len(text_chunks)} of segment {seg_idx+1}...")
                            # Format chunk for VibeVoice
                            formatted_text = self._format_text_for_vibevoice(chunk, speakers)
                            # Create voice samples on first text segment
                            if voice_samples is None:
                                voice_samples = self._prepare_voice_samples(speakers, voice_to_clone)
                            # Generate audio for this chunk
                            chunk_audio = self._generate_with_vibevoice(
                                formatted_text, voice_samples, cfg_scale,
                                seed,  # Use same seed for voice consistency
                                diffusion_steps, use_sampling, temperature, top_p
                            )
                            all_audio_segments.append(chunk_audio)
                    else:
                        # Process as single chunk
                        logger.info(f"Processing text segment {seg_idx+1} ({word_count} words)")
                        # Format text for VibeVoice
                        formatted_text = self._format_text_for_vibevoice(seg_content, speakers)
                        # Create voice samples on first text segment
                        if voice_samples is None:
                            voice_samples = self._prepare_voice_samples(speakers, voice_to_clone)
                        # Generate audio
                        segment_audio = self._generate_with_vibevoice(
                            formatted_text, voice_samples, cfg_scale, seed, diffusion_steps,
                            use_sampling, temperature, top_p
                        )
                        all_audio_segments.append(segment_audio)
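
            # Each element of all_audio_segments is assumed to be a ComfyUI-style audio
            # dict ({"waveform": tensor, "sample_rate": int}); concatenating waveforms on
            # the last (time) axis below requires matching leading (batch/channel) dims.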
            # Concatenate all audio segments (including pauses)
            if all_audio_segments:
                logger.info(f"Concatenating {len(all_audio_segments)} audio segments (including pauses)...")

                # Extract waveforms from all segments
                waveforms = []
                for audio_segment in all_audio_segments:
                    if isinstance(audio_segment, dict) and "waveform" in audio_segment:
                        waveforms.append(audio_segment["waveform"])

                if waveforms:
                    # Filter out None values if any
                    valid_waveforms = [w for w in waveforms if w is not None]
                    if valid_waveforms:
                        # Concatenate along the time dimension (last dimension)
                        combined_waveform = torch.cat(valid_waveforms, dim=-1)
                        # Create final audio dict
                        audio_dict = {
                            "waveform": combined_waveform,
                            "sample_rate": sample_rate
                        }
                        logger.info(f"Successfully generated audio from {len(all_audio_segments)} audio segments")
                    else:
                        raise Exception("No valid audio waveforms generated")
                else:
                    raise Exception("Failed to extract waveforms from audio segments")
            else:
                raise Exception("No audio segments generated")

            # Free memory if requested
            if free_memory_after_generate:
                self.free_memory()

            return (audio_dict,)

        except Exception as e:
            # Check if this is an interruption by the user
            import comfy.model_management as mm
            if isinstance(e, mm.InterruptProcessingException):
                # User interrupted - just log it and re-raise to stop the workflow
                logger.info("Generation interrupted by user")
                raise  # Propagate the interruption to stop the workflow
            else:
                # Real error - show it
                logger.error(f"Single speaker speech generation failed: {str(e)}")
                raise Exception(f"Error generating speech: {str(e)}") from e

    @classmethod
    def IS_CHANGED(cls, text="", model="VibeVoice-1.5B", voice_to_clone=None, **kwargs):
        """Cache key for ComfyUI"""
        voice_hash = hash(str(voice_to_clone)) if voice_to_clone else 0
        return f"{hash(text)}_{model}_{voice_hash}_{kwargs.get('cfg_scale', 1.3)}_{kwargs.get('seed', 0)}"
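
# A minimal usage sketch (illustrative only): ComfyUI discovers custom nodes through
# NODE_CLASS_MAPPINGS, which for this package most likely lives in another module
# (e.g. the package __init__.py), not in this file:
#
#   NODE_CLASS_MAPPINGS = {"VibeVoiceSingleSpeakerNode": VibeVoiceSingleSpeakerNode}
#   NODE_DISPLAY_NAME_MAPPINGS = {"VibeVoiceSingleSpeakerNode": "VibeVoice Single Speaker"}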