File size: 4,372 Bytes
cd0b70a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Created by Fabio Sarracino
# Node to free VibeVoice model memory

import logging
import torch
import gc
from typing import Any

# Setup logging
logger = logging.getLogger("VibeVoice")

class VibeVoiceFreeMemoryNode:
    """Node to explicitly free VibeVoice model memory"""
    
    # Class variables to store node instances
    _single_speaker_instances = []
    _multi_speaker_instances = []
    
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "audio": ("AUDIO", {"tooltip": "Audio input that triggers memory cleanup and gets passed through"}),
            }
        }

    RETURN_TYPES = ("AUDIO",)
    RETURN_NAMES = ("audio",)
    FUNCTION = "free_vibevoice_memory"
    CATEGORY = "VibeVoiceWrapper"
    DESCRIPTION = "Free all loaded VibeVoice models from memory when audio passes through"
    
    @classmethod
    def register_single_speaker(cls, node_instance):
        """Register a single speaker node instance"""
        if node_instance not in cls._single_speaker_instances:
            cls._single_speaker_instances.append(node_instance)
    
    @classmethod
    def register_multi_speaker(cls, node_instance):
        """Register a multi speaker node instance"""
        if node_instance not in cls._multi_speaker_instances:
            cls._multi_speaker_instances.append(node_instance)
    
    def free_vibevoice_memory(self, audio):
        """Free memory from all VibeVoice nodes and pass through the audio"""
        
        try:
            freed_count = 0
            
            # Try to access and free memory from globally cached instances
            # ComfyUI might cache node instances
            try:
                import sys
                from .base_vibevoice import BaseVibeVoiceNode
                
                # Search in all modules for BaseVibeVoiceNode instances
                for module_name, module in sys.modules.items():
                    if module and 'vibevoice' in module_name.lower():
                        for attr_name in dir(module):
                            if not attr_name.startswith('_'):
                                try:
                                    attr = getattr(module, attr_name)
                                    if isinstance(attr, type) and issubclass(attr, BaseVibeVoiceNode):
                                        # Check if the class has any cached instances
                                        for instance_attr in dir(attr):
                                            instance = getattr(attr, instance_attr)
                                            if isinstance(instance, BaseVibeVoiceNode) and hasattr(instance, 'free_memory'):
                                                instance.free_memory()
                                                freed_count += 1
                                except:
                                    pass
            except:
                pass
            
            # Free from registered single speaker instances
            for node in self._single_speaker_instances:
                if hasattr(node, 'free_memory'):
                    node.free_memory()
                    freed_count += 1
            
            # Free from registered multi speaker instances  
            for node in self._multi_speaker_instances:
                if hasattr(node, 'free_memory'):
                    node.free_memory()
                    freed_count += 1
            
            # Force garbage collection
            gc.collect()
            
            # Clear CUDA cache if available
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                logger.info(f"Freed VibeVoice memory from {freed_count} nodes and cleared CUDA cache")
            else:
                logger.info(f"Freed VibeVoice memory from {freed_count} nodes")
            
            # Pass through the audio unchanged
            return (audio,)
                
        except Exception as e:
            logger.error(f"Error freeing VibeVoice memory: {str(e)}")
            # Still pass through audio even if error occurs
            return (audio,)
    
    @classmethod
    def IS_CHANGED(cls, **kwargs):
        """Always execute this node"""
        return float("nan")  # Forces re-execution every time