Commit ac6279d · Parent(s): ad86dfe · committed by alex
higgs text-to-speech added
Files changed:
- app.py +46 -1
 - examples/audios/config.json +10 -0
 - higgs_audio/__init__.py +1 -0
 - higgs_audio/audio_processing/LICENSE +51 -0
 - higgs_audio/audio_processing/descriptaudiocodec/__init__.py +0 -0
 - higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py +286 -0
 - higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py +365 -0
 - higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py +33 -0
 - higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py +251 -0
 - higgs_audio/audio_processing/higgs_audio_tokenizer.py +341 -0
 - higgs_audio/audio_processing/quantization/__init__.py +8 -0
 - higgs_audio/audio_processing/quantization/ac.py +301 -0
 - higgs_audio/audio_processing/quantization/core_vq.py +360 -0
 - higgs_audio/audio_processing/quantization/core_vq_lsx_version.py +431 -0
 - higgs_audio/audio_processing/quantization/ddp_utils.py +197 -0
 - higgs_audio/audio_processing/quantization/distrib.py +123 -0
 - higgs_audio/audio_processing/quantization/vq.py +116 -0
 - higgs_audio/audio_processing/semantic_module.py +310 -0
 - higgs_audio/constants.py +3 -0
 - higgs_audio/data_collator/__init__.py +0 -0
 - higgs_audio/data_collator/higgs_audio_collator.py +583 -0
 - higgs_audio/data_types.py +38 -0
 - higgs_audio/dataset/__init__.py +0 -0
 - higgs_audio/dataset/chatml_dataset.py +554 -0
 - higgs_audio/model/__init__.py +9 -0
 - higgs_audio/model/audio_head.py +139 -0
 - higgs_audio/model/common.py +27 -0
 - higgs_audio/model/configuration_higgs_audio.py +235 -0
 - higgs_audio/model/cuda_graph_runner.py +129 -0
 - higgs_audio/model/custom_modules.py +155 -0
 - higgs_audio/model/modeling_higgs_audio.py +0 -0
 - higgs_audio/model/utils.py +778 -0
 - higgs_audio/serve/serve_engine.py +474 -0
 - higgs_audio/serve/utils.py +254 -0
 - higgs_audio_utils.py +290 -0
 - requirements.txt +17 -2
 
    	
app.py CHANGED

@@ -57,6 +57,8 @@ from tqdm import tqdm
 from functools import partial
 from omegaconf import OmegaConf
 from argparse import Namespace
+from gradio_extendedaudio import ExtendedAudio
+import torchaudio
 
 # load the one true config you dumped
 _args_cfg = OmegaConf.load("args_config.yaml")
@@ -78,10 +80,46 @@ from transformers import Wav2Vec2FeatureExtractor
 import torchvision.transforms as transforms
 import torch.nn.functional as F
 from OmniAvatar.utils.audio_preprocess import add_silence_to_audio_ffmpeg
+from higgs_audio_utils import text_to_speech, initialize_engine
 
 
+DEFAULT_TTS_MODEL_PATH = "bosonai/higgs-audio-v2-generation-3B-base"
+DEFAULT_AUDIO_TOKENIZER_PATH = "bosonai/higgs-audio-v2-tokenizer"
+engine = initialize_engine(DEFAULT_TTS_MODEL_PATH, DEFAULT_AUDIO_TOKENIZER_PATH)
+
 os.environ["PROCESSED_RESULTS"] = f"{os.getcwd()}/proprocess_results"
 
+@spaces.GPU
+def tts_from_text(text, voice_choice):
+    _, output = text_to_speech(engine, text, voice_preset=voice_choice)
+    return output
+
+def speak_to_me(session_id, evt: gr.EventData):
+    detail = getattr(evt, "data", None) or getattr(evt, "_data", {}) or {}
+
+    current_text = detail.get("text", "")
+    current_choice = detail.get("choice", "")
+
+    print(current_choice)
+
+    output = tts_from_text(current_text, current_choice)
+
+    if session_id is None:
+        session_id = uuid.uuid4().hex
+
+    output_dir = os.path.join(os.environ["PROCESSED_RESULTS"], session_id)
+
+    tts_dir = output_dir + '/tts'
+    os.makedirs(tts_dir, exist_ok=True)
+    speech_to_text_path = os.path.join(tts_dir, f"speech_to_text.wav")
+
+    sampling_rate = output[0]
+    audio_data = output[1]
+
+    torchaudio.save(speech_to_text_path, torch.from_numpy(audio_data)[None, :], output[0])
+
+    return speech_to_text_path
+
 def tensor_to_pil(tensor):
     """
     Args:
@@ -726,7 +764,7 @@ with gr.Blocks(css=css) as demo:
             with gr.Column():
 
                 image_input = gr.Image(label="Reference Image", type="filepath", height=512)
-                audio_input =
+                audio_input = ExtendedAudio(label="Input Audio", type="filepath", value="examples/audios/berry.wav", options=["Cleo", "Cleon"])
 
 
             with gr.Column():
@@ -812,6 +850,13 @@ with gr.Blocks(css=css) as demo:
         inputs=[image_input, audio_input, text_input, num_steps, session_state],
         outputs=[output_video]
     )
+
+    audio_input.generate(
+        fn=speak_to_me,
+        inputs=[session_state],
+        outputs=[audio_input]
+    )
+
     image_input.upload(fn=preprocess_img, inputs=[image_input, session_state], outputs=[image_input])
     image_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
     audio_input.change(fn=update_generate_button, inputs=[image_input, audio_input, text_input, num_steps], outputs=[time_required])
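For context, a minimal standalone sketch of the TTS path this diff wires into the UI, assuming text_to_speech returns a (text, (sampling_rate, waveform)) pair, as the unpacking in tts_from_text and speak_to_me above implies; the prompt text and output path are illustrative:

import torch
import torchaudio

from higgs_audio_utils import initialize_engine, text_to_speech

# Same checkpoints the app loads at import time
engine = initialize_engine(
    "bosonai/higgs-audio-v2-generation-3B-base",
    "bosonai/higgs-audio-v2-tokenizer",
)

# voice_preset names come from examples/audios/config.json ("Cleo" / "Cleon")
_, (sampling_rate, waveform) = text_to_speech(engine, "Hello there.", voice_preset="Cleo")

# torchaudio.save expects a (channels, samples) tensor, hence the [None, :]
torchaudio.save("tts_demo.wav", torch.from_numpy(waveform)[None, :], sampling_rate)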
    	
examples/audios/config.json ADDED

@@ -0,0 +1,10 @@
+{
+  "Cleon": {
+        "transcript": "Strive not to be a success, but rather to be of value.",
+        "audio_file": "story_telling_voice_4.wav"
+    },
+  "Cleo": {
+        "transcript": "Strive not to be a success, but rather to be of value.",
+        "audio_file": "story_telling_voice_2.wav"
+    }
+}
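A short sketch of how a preset map like this can be resolved to a reference clip and transcript for voice cloning; the actual lookup presumably lives behind text_to_speech's voice_preset argument, so this standalone reader is only illustrative:

import json
import os

CONFIG_DIR = "examples/audios"

with open(os.path.join(CONFIG_DIR, "config.json")) as f:
    presets = json.load(f)

# Each preset pairs a reference recording with its transcript, which
# reference-based TTS uses as the conditioning example for the voice.
voice = presets["Cleo"]
ref_wav = os.path.join(CONFIG_DIR, voice["audio_file"])
ref_transcript = voice["transcript"]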
    	
higgs_audio/__init__.py ADDED

@@ -0,0 +1 @@
+from .model import HiggsAudioConfig, HiggsAudioModel
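Since the package root now re-exports the model classes, callers can import them directly:

from higgs_audio import HiggsAudioConfig, HiggsAudioModel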
    	
higgs_audio/audio_processing/LICENSE ADDED

@@ -0,0 +1,51 @@
+Third-Party License Attribution for Audio Processing Module
+===========================================================
+
+This directory contains code derived from multiple open-source projects.
+The following sections detail the licenses and attributions for third-party code.
+
+## XCodec Repository
+The code in this directory is derived from:
+https://github.com/zhenye234/xcodec
+
+## Individual File Attributions
+
+### Quantization Module (quantization/)
+- Several files contain code derived from Meta Platforms, Inc. and the vector-quantize-pytorch repository
+- Individual files contain their own license headers where applicable
+- The vector-quantize-pytorch portions are licensed under the MIT License
+
+## License Terms
+
+### MIT License (for applicable portions)
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+## Attribution Requirements
+When using this code, please ensure proper attribution to:
+1. The original xcodec repository: https://github.com/zhenye234/xcodec
+2. Any other repositories mentioned in individual file headers
+3. This derivative work and its modifications
+
+## Disclaimer
+This directory contains modified versions of the original code. Please refer to
+the original repositories for the canonical implementations and their specific
+license terms.
+
+For any questions about licensing or attribution, please check the individual
+file headers and the original source repositories.
    	
higgs_audio/audio_processing/descriptaudiocodec/__init__.py ADDED

(empty file)
    	
higgs_audio/audio_processing/descriptaudiocodec/dac/model/base.py ADDED

@@ -0,0 +1,286 @@
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Union
+
+import numpy as np
+import torch
+import tqdm
+from audiotools import AudioSignal
+from torch import nn
+
+SUPPORTED_VERSIONS = ["1.0.0"]
+
+
+@dataclass
+class DACFile:
+    codes: torch.Tensor
+
+    # Metadata
+    chunk_length: int
+    original_length: int
+    input_db: float
+    channels: int
+    sample_rate: int
+    padding: bool
+    dac_version: str
+
+    def save(self, path):
+        artifacts = {
+            "codes": self.codes.numpy().astype(np.uint16),
+            "metadata": {
+                "input_db": self.input_db.numpy().astype(np.float32),
+                "original_length": self.original_length,
+                "sample_rate": self.sample_rate,
+                "chunk_length": self.chunk_length,
+                "channels": self.channels,
+                "padding": self.padding,
+                "dac_version": SUPPORTED_VERSIONS[-1],
+            },
+        }
+        path = Path(path).with_suffix(".dac")
+        with open(path, "wb") as f:
+            np.save(f, artifacts)
+        return path
+
+    @classmethod
+    def load(cls, path):
+        artifacts = np.load(path, allow_pickle=True)[()]
+        codes = torch.from_numpy(artifacts["codes"].astype(int))
+        if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS:
+            raise RuntimeError(f"Given file {path} can't be loaded with this version of descript-audio-codec.")
+        return cls(codes=codes, **artifacts["metadata"])
+
+
+class CodecMixin:
+    @property
+    def padding(self):
+        if not hasattr(self, "_padding"):
+            self._padding = True
+        return self._padding
+
+    @padding.setter
+    def padding(self, value):
+        assert isinstance(value, bool)
+
+        layers = [l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d))]
+
+        for layer in layers:
+            if value:
+                if hasattr(layer, "original_padding"):
+                    layer.padding = layer.original_padding
+            else:
+                layer.original_padding = layer.padding
+                layer.padding = tuple(0 for _ in range(len(layer.padding)))
+
+        self._padding = value
+
+    def get_delay(self):
+        # Any number works here, delay is invariant to input length
+        l_out = self.get_output_length(0)
+        L = l_out
+
+        layers = []
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                layers.append(layer)
+
+        for layer in reversed(layers):
+            d = layer.dilation[0]
+            k = layer.kernel_size[0]
+            s = layer.stride[0]
+
+            if isinstance(layer, nn.ConvTranspose1d):
+                L = ((L - d * (k - 1) - 1) / s) + 1
+            elif isinstance(layer, nn.Conv1d):
+                L = (L - 1) * s + d * (k - 1) + 1
+
+            L = math.ceil(L)
+
+        l_in = L
+
+        return (l_in - l_out) // 2
+
+    def get_output_length(self, input_length):
+        L = input_length
+        # Calculate output length
+        for layer in self.modules():
+            if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)):
+                d = layer.dilation[0]
+                k = layer.kernel_size[0]
+                s = layer.stride[0]
+
+                if isinstance(layer, nn.Conv1d):
+                    L = ((L - d * (k - 1) - 1) / s) + 1
+                elif isinstance(layer, nn.ConvTranspose1d):
+                    L = (L - 1) * s + d * (k - 1) + 1
+
+                L = math.floor(L)
+        return L
+
+    @torch.no_grad()
+    def compress(
+        self,
+        audio_path_or_signal: Union[str, Path, AudioSignal],
+        win_duration: float = 1.0,
+        verbose: bool = False,
+        normalize_db: float = -16,
+        n_quantizers: int = None,
+    ) -> DACFile:
+        """Processes an audio signal from a file or AudioSignal object into
+        discrete codes. This function processes the signal in short windows,
+        using constant GPU memory.
+
+        Parameters
+        ----------
+        audio_path_or_signal : Union[str, Path, AudioSignal]
+            audio signal to reconstruct
+        win_duration : float, optional
+            window duration in seconds, by default 5.0
+        verbose : bool, optional
+            by default False
+        normalize_db : float, optional
+            normalize db, by default -16
+
+        Returns
+        -------
+        DACFile
+            Object containing compressed codes and metadata
+            required for decompression
+        """
+        audio_signal = audio_path_or_signal
+        if isinstance(audio_signal, (str, Path)):
+            audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal))
+
+        self.eval()
+        original_padding = self.padding
+        original_device = audio_signal.device
+
+        audio_signal = audio_signal.clone()
+        original_sr = audio_signal.sample_rate
+
+        resample_fn = audio_signal.resample
+        loudness_fn = audio_signal.loudness
+
+        # If audio is > 10 minutes long, use the ffmpeg versions
+        if audio_signal.signal_duration >= 10 * 60 * 60:
+            resample_fn = audio_signal.ffmpeg_resample
+            loudness_fn = audio_signal.ffmpeg_loudness
+
+        original_length = audio_signal.signal_length
+        resample_fn(self.sample_rate)
+        input_db = loudness_fn()
+
+        if normalize_db is not None:
+            audio_signal.normalize(normalize_db)
+        audio_signal.ensure_max_of_audio()
+
+        nb, nac, nt = audio_signal.audio_data.shape
+        audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt)
+        win_duration = audio_signal.signal_duration if win_duration is None else win_duration
+
+        if audio_signal.signal_duration <= win_duration:
+            # Unchunked compression (used if signal length < win duration)
+            self.padding = True
+            n_samples = nt
+            hop = nt
+        else:
+            # Chunked inference
+            self.padding = False
+            # Zero-pad signal on either side by the delay
+            audio_signal.zero_pad(self.delay, self.delay)
+            n_samples = int(win_duration * self.sample_rate)
+            # Round n_samples to nearest hop length multiple
+            n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length)
+            hop = self.get_output_length(n_samples)
+
+        codes = []
+        range_fn = range if not verbose else tqdm.trange
+
+        for i in range_fn(0, nt, hop):
+            x = audio_signal[..., i : i + n_samples]
+            x = x.zero_pad(0, max(0, n_samples - x.shape[-1]))
+
+            audio_data = x.audio_data.to(self.device)
+            audio_data = self.preprocess(audio_data, self.sample_rate)
+            _, c, _, _, _ = self.encode(audio_data, n_quantizers)
+            codes.append(c.to(original_device))
+            chunk_length = c.shape[-1]
+
+        codes = torch.cat(codes, dim=-1)
+
+        dac_file = DACFile(
+            codes=codes,
+            chunk_length=chunk_length,
+            original_length=original_length,
+            input_db=input_db,
+            channels=nac,
+            sample_rate=original_sr,
+            padding=self.padding,
+            dac_version=SUPPORTED_VERSIONS[-1],
+        )
+
+        if n_quantizers is not None:
+            codes = codes[:, :n_quantizers, :]
+
+        self.padding = original_padding
+        return dac_file
+
+    @torch.no_grad()
+    def decompress(
+        self,
+        obj: Union[str, Path, DACFile],
+        verbose: bool = False,
+    ) -> AudioSignal:
+        """Reconstruct audio from a given .dac file
+
+        Parameters
+        ----------
+        obj : Union[str, Path, DACFile]
+            .dac file location or corresponding DACFile object.
+        verbose : bool, optional
+            Prints progress if True, by default False
+
+        Returns
+        -------
+        AudioSignal
+            Object with the reconstructed audio
+        """
+        self.eval()
+        if isinstance(obj, (str, Path)):
+            obj = DACFile.load(obj)
+
+        original_padding = self.padding
+        self.padding = obj.padding
+
+        range_fn = range if not verbose else tqdm.trange
+        codes = obj.codes
+        original_device = codes.device
+        chunk_length = obj.chunk_length
+        recons = []
+
+        for i in range_fn(0, codes.shape[-1], chunk_length):
+            c = codes[..., i : i + chunk_length].to(self.device)
+            z = self.quantizer.from_codes(c)[0]
+            r = self.decode(z)
+            recons.append(r.to(original_device))
+
+        recons = torch.cat(recons, dim=-1)
+        recons = AudioSignal(recons, self.sample_rate)
+
+        resample_fn = recons.resample
+        loudness_fn = recons.loudness
+
+        # If audio is > 10 minutes long, use the ffmpeg versions
+        if recons.signal_duration >= 10 * 60 * 60:
+            resample_fn = recons.ffmpeg_resample
+            loudness_fn = recons.ffmpeg_loudness
+
+        recons.normalize(obj.input_db)
+        resample_fn(obj.sample_rate)
+        recons = recons[..., : obj.original_length]
+        loudness_fn()
+        recons.audio_data = recons.audio_data.reshape(-1, obj.channels, obj.original_length)
+
+        self.padding = original_padding
+        return recons
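A usage sketch for the CodecMixin API above, assuming a trained codec model that mixes it in (such as the DAC class defined in dac.py, the next file) and that the vendored dac package imports resolve; the paths are placeholders:

import torch

from higgs_audio.audio_processing.descriptaudiocodec.dac.model.dac import DAC

model = DAC()  # in practice, load trained weights; fresh weights decode to noise
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Windowed encode: constant GPU memory regardless of input length
dac_file = model.compress("examples/audios/berry.wav", win_duration=1.0)
dac_file.save("berry.dac")  # codes stored as uint16 plus loudness/length metadata

# Decode chunk by chunk, then restore loudness and sample rate from metadata
recons = model.decompress("berry.dac", verbose=True)
recons.write("berry_recons.wav")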
    	
        higgs_audio/audio_processing/descriptaudiocodec/dac/model/dac.py
    ADDED
    
    | 
         @@ -0,0 +1,365 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
import math
from typing import List
from typing import Union

import numpy as np
import torch
from audiotools import AudioSignal
from audiotools.ml import BaseModel
from torch import nn

from .base import CodecMixin
from dac.nn.layers import Snake1d
from dac.nn.layers import WNConv1d
from dac.nn.layers import WNConvTranspose1d
from dac.nn.quantize import ResidualVectorQuantize


def init_weights(m):
    if isinstance(m, nn.Conv1d):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)


class ResidualUnit(nn.Module):
    def __init__(self, dim: int = 16, dilation: int = 1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.block = nn.Sequential(
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad),
            Snake1d(dim),
            WNConv1d(dim, dim, kernel_size=1),
        )

    def forward(self, x):
        y = self.block(x)
        pad = (x.shape[-1] - y.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + y


class EncoderBlock(nn.Module):
    def __init__(self, dim: int = 16, stride: int = 1):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(dim // 2, dilation=1),
            ResidualUnit(dim // 2, dilation=3),
            ResidualUnit(dim // 2, dilation=9),
            Snake1d(dim // 2),
            WNConv1d(
                dim // 2,
                dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            ),
        )

    def forward(self, x):
        return self.block(x)


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int = 64,
        strides: list = [2, 4, 8, 8],
        d_latent: int = 256,
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [EncoderBlock(d_model, stride=stride)]

        # Create last convolution
        self.block += [
            Snake1d(d_model),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1),
        ]

        # Wrap block into nn.Sequential
        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model

    def forward(self, x):
        return self.block(x)
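A quick, hypothetical shape check (illustrative only, not part of the commit) helps read the encoder geometry: with the default strides [2, 4, 8, 8] the time axis is downsampled by 2 * 4 * 8 * 8 = 512, channels double at each block up to 64 * 2**4 = 1024, and a final kernel-3 conv projects to d_latent:

# Illustrative sketch, assuming the default Encoder arguments above.
import torch

enc = Encoder(d_model=64, strides=[2, 4, 8, 8], d_latent=256)
wav = torch.randn(1, 1, 512 * 100)   # [B, 1, T], T divisible by the 512x hop
lat = enc(wav)
print(lat.shape)                     # torch.Size([1, 256, 100])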
         

class DecoderBlock(nn.Module):
    def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1, out_pad=0):
        super().__init__()
        self.block = nn.Sequential(
            Snake1d(input_dim),
            WNConvTranspose1d(
                input_dim,
                output_dim,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
                output_padding=stride % 2,  # out_pad,
            ),
            ResidualUnit(output_dim, dilation=1),
            ResidualUnit(output_dim, dilation=3),
            ResidualUnit(output_dim, dilation=9),
        )

    def forward(self, x):
        return self.block(x)


class Decoder(nn.Module):
    def __init__(
        self,
        input_channel,
        channels,
        rates,
        d_out: int = 1,
    ):
        super().__init__()

        # Add first conv layer
        layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)]

        # Add upsampling + MRF blocks
        for i, stride in enumerate(rates):
            input_dim = channels // 2**i
            output_dim = channels // 2 ** (i + 1)
            if i == 1:
                out_pad = 1
            else:
                out_pad = 0
            layers += [DecoderBlock(input_dim, output_dim, stride, out_pad)]

        # Add final conv layer
        layers += [
            Snake1d(output_dim),
            WNConv1d(output_dim, d_out, kernel_size=7, padding=3),
            # nn.Tanh(),
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
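The decoder mirrors the encoder: each DecoderBlock upsamples by its rate through the transposed conv, so the stack restores the 512x hop whenever the rates multiply to the encoder's stride product. A hypothetical mirror check (illustrative only, not part of the commit):

# Illustrative sketch, assuming the default DAC decoder configuration.
dec = Decoder(input_channel=256, channels=1536, rates=[8, 8, 4, 2])
lat = torch.randn(1, 256, 100)   # latent frames at 1/512 of the audio rate
wav = dec(lat)
print(wav.shape)                 # torch.Size([1, 1, 51200])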
         

class DAC(BaseModel, CodecMixin):
    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: List[int] = [2, 4, 8, 8],
        latent_dim: int = None,
        decoder_dim: int = 1536,
        decoder_rates: List[int] = [8, 8, 4, 2],
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: bool = False,
        sample_rate: int = 44100,
    ):
        super().__init__()

        self.encoder_dim = encoder_dim
        self.encoder_rates = encoder_rates
        self.decoder_dim = decoder_dim
        self.decoder_rates = decoder_rates
        self.sample_rate = sample_rate

        if latent_dim is None:
            latent_dim = encoder_dim * (2 ** len(encoder_rates))

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim)

        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer = ResidualVectorQuantize(
            input_dim=latent_dim,
            n_codebooks=n_codebooks,
            codebook_size=codebook_size,
            codebook_dim=codebook_dim,
            quantizer_dropout=quantizer_dropout,
        )

        self.decoder = Decoder(
            latent_dim,
            decoder_dim,
            decoder_rates,
        )
        self.sample_rate = sample_rate
        self.apply(init_weights)

        self.delay = self.get_delay()

    def preprocess(self, audio_data, sample_rate):
        if sample_rate is None:
            sample_rate = self.sample_rate
        assert sample_rate == self.sample_rate

        length = audio_data.shape[-1]
        right_pad = math.ceil(length / self.hop_length) * self.hop_length - length
        audio_data = nn.functional.pad(audio_data, (0, right_pad))

        return audio_data
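The padding arithmetic in preprocess is easiest to see with numbers. With the default hop_length of 512 (the product of the encoder rates), a one-second 44100-sample clip is right-padded up to the next hop multiple:

# Illustrative only: math.ceil(44100 / 512) = 87 frames,
# so the clip grows to 87 * 512 = 44544 samples (444 zeros appended).
right_pad = math.ceil(44100 / 512) * 512 - 44100   # 444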
         
    def encode(
        self,
        audio_data: torch.Tensor,
        n_quantizers: int = None,
    ):
        """Encode given audio data and return quantized latent codes

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        n_quantizers : int, optional
            Number of quantizers to use, by default None
            If None, all quantizers are used.

        Returns
        -------
        tuple
            (z, codes, latents, commitment_loss, codebook_loss), where:
            z : Tensor[B x D x T]
                Quantized continuous representation of input
            codes : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            latents : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            commitment_loss : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            codebook_loss : Tensor[1]
                Codebook loss to update the codebook
        """
        z = self.encoder(audio_data)
        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(z, n_quantizers)
        return z, codes, latents, commitment_loss, codebook_loss

    def decode(self, z: torch.Tensor):
        """Decode given latent codes and return audio data

        Parameters
        ----------
        z : Tensor[B x D x T]
            Quantized continuous representation of input

        Returns
        -------
        Tensor[B x 1 x T]
            Decoded audio data.
        """
        return self.decoder(z)

    def forward(
        self,
        audio_data: torch.Tensor,
        sample_rate: int = None,
        n_quantizers: int = None,
    ):
        """Model forward pass

        Parameters
        ----------
        audio_data : Tensor[B x 1 x T]
            Audio data to encode
        sample_rate : int, optional
            Sample rate of audio data in Hz, by default None
            If None, defaults to `self.sample_rate`
        n_quantizers : int, optional
            Number of quantizers to use, by default None.
            If None, all quantizers are used.

        Returns
        -------
        dict
            A dictionary with the following keys:
            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
            "audio" : Tensor[B x 1 x length]
                Decoded audio data.
        """
        length = audio_data.shape[-1]
        audio_data = self.preprocess(audio_data, sample_rate)
        z, codes, latents, commitment_loss, codebook_loss = self.encode(audio_data, n_quantizers)

        x = self.decode(z)
        return {
            "audio": x[..., :length],
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss,
            "vq/codebook_loss": codebook_loss,
        }


if __name__ == "__main__":
    import numpy as np
    from functools import partial

    model = DAC().to("cpu")

    for n, m in model.named_modules():
        o = m.extra_repr()
        p = sum([np.prod(p.size()) for p in m.parameters()])
        fn = lambda o, p: o + f" {p / 1e6:<.3f}M params."
        setattr(m, "extra_repr", partial(fn, o=o, p=p))
    print(model)
    print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))

    length = 88200 * 2
    x = torch.randn(1, 1, length).to(model.device)
    x.requires_grad_(True)
    x.retain_grad()

    # Make a forward pass
    out = model(x)["audio"]
    print("Input shape:", x.shape)
    print("Output shape:", out.shape)

    # Create gradient variable
    grad = torch.zeros_like(out)
    grad[:, :, grad.shape[-1] // 2] = 1

    # Make a backward pass
    out.backward(grad)

    # Check non-zero values
    gradmap = x.grad.squeeze(0)
    gradmap = (gradmap != 0).sum(0)  # sum across features
    rf = (gradmap != 0).sum()

    print(f"Receptive field: {rf.item()}")

    x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100)
    model.decompress(model.compress(x, verbose=True), verbose=True)
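Beyond the receptive-field probe above, a minimal round-trip sketch (hypothetical, not part of the commit) shows the shapes the model exposes; with the defaults, "codes" carries 9 codebook indices per 512-sample hop, and the reconstruction is trimmed back to the input length:

# Illustrative sketch, assuming the default 44.1 kHz DAC configuration.
model = DAC().eval()
wav = torch.randn(1, 1, 44100)
with torch.no_grad():
    out = model(wav, sample_rate=44100)
print(out["codes"].shape)   # torch.Size([1, 9, 87]) -- 44100 samples pad to 87 hops
print(out["audio"].shape)   # torch.Size([1, 1, 44100]) -- trimmed to input length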
         
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/layers.py ADDED
@@ -0,0 +1,33 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


def WNConv1d(*args, **kwargs):
    return weight_norm(nn.Conv1d(*args, **kwargs))


def WNConvTranspose1d(*args, **kwargs):
    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))


# Scripting this brings model speed up 1.4x
@torch.jit.script
def snake(x, alpha):
    shape = x.shape
    x = x.reshape(shape[0], shape[1], -1)
    x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
    x = x.reshape(shape)
    return x


class Snake1d(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return snake(x, self.alpha)
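Since alpha is initialized to ones, a hypothetical sanity check (illustrative only, not part of the commit) makes the activation concrete: snake(x, 1) is just x + sin(x)^2, and the learned per-channel alpha only rescales the sine's frequency and amplitude:

# Illustrative only: with the initial alpha = 1, Snake1d reduces to x + sin(x)^2.
act = Snake1d(channels=4)
x = torch.randn(2, 4, 16)
y = act(x)
assert torch.allclose(y, x + torch.sin(x).pow(2), atol=1e-5)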
         
higgs_audio/audio_processing/descriptaudiocodec/dac/nn/quantize.py ADDED
@@ -0,0 +1,251 @@
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm

from dac.nn.layers import WNConv1d


class VectorQuantize(nn.Module):
    """
    Implementation of VQ similar to Karpathy's repo:
    https://github.com/karpathy/deep-vector-quantization
    Additionally uses following tricks from Improved VQGAN
    (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space
            for improved codebook usage
        2. l2-normalized codes: Converts euclidean distance to cosine similarity which
            improves training stability
    """

    def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)
        self.codebook = nn.Embedding(codebook_size, codebook_dim)

    def forward(self, z):
        """Quantizes the input tensor using a fixed codebook and returns
        the corresponding codebook vectors

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input
        Tensor[1]
            Commitment loss to train encoder to predict vectors closer to codebook
            entries
        Tensor[1]
            Codebook loss to update the codebook
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x D x T]
            Projected latents (continuous representation of input before quantization)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x D x T)
        z_q, indices = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
        codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])

        z_q = z_e + (z_q - z_e).detach()  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, codebook_loss, indices, z_e

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight  # codebook: (N x D)

        # L2 normalize encodings and codebook (ViT-VQGAN)
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Compute euclidean distance with codebook
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
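A hypothetical shape check (illustrative only, not part of the commit) ties the pieces together: the factorized lookup happens in the low-dimensional projected space, indices pick one code per timestep, and out_proj restores the input width, while the straight-through line above lets gradients flow through the discrete choice:

# Illustrative sketch, assuming the default dimensions used elsewhere in this file.
vq = VectorQuantize(input_dim=512, codebook_size=1024, codebook_dim=8)
z = torch.randn(2, 512, 50)
z_q, commit, cb, idx, z_e = vq(z)
print(z_q.shape, idx.shape, z_e.shape)   # [2, 512, 50] [2, 50] [2, 8, 50]
print(commit.shape)                      # per-item loss: torch.Size([2])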
         

class ResidualVectorQuantize(nn.Module):
    """
    Introduced in SoundStream: An end2end neural audio codec
    https://arxiv.org/abs/2107.03312
    """

    def __init__(
        self,
        input_dim: int = 512,
        n_codebooks: int = 9,
        codebook_size: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
    ):
        super().__init__()
        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim for _ in range(n_codebooks)]

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        self.quantizers = nn.ModuleList(
            [VectorQuantize(input_dim, codebook_size, codebook_dim[i]) for i in range(n_codebooks)]
        )
        self.quantizer_dropout = quantizer_dropout

    def forward(self, z, n_quantizers: int = None):
        """Quantizes the input tensor using a fixed set of `n` codebooks and returns
        the corresponding codebook vectors
        Parameters
        ----------
        z : Tensor[B x D x T]
        n_quantizers : int, optional
            No. of quantizers to use
            (n_quantizers < self.n_codebooks ex: for quantizer dropout)
            Note: if `self.quantizer_dropout` is True, this argument is ignored
                when in training mode, and a random number of quantizers is used.
        Returns
        -------
        dict
            A dictionary with the following keys:

            "z" : Tensor[B x D x T]
                Quantized continuous representation of input
            "codes" : Tensor[B x N x T]
                Codebook indices for each codebook
                (quantized discrete representation of input)
            "latents" : Tensor[B x N*D x T]
                Projected latents (continuous representation of input before quantization)
            "vq/commitment_loss" : Tensor[1]
                Commitment loss to train encoder to predict vectors closer to codebook
                entries
            "vq/codebook_loss" : Tensor[1]
                Codebook loss to update the codebook
        """
        z_q = 0
        residual = z
        commitment_loss = 0
        codebook_loss = 0

        codebook_indices = []
        latents = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks
        if self.training:
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer(residual)

            # Create mask to apply quantizer dropout
            mask = torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
         
     | 
| 176 | 
         
            +
                        z_q = z_q + z_q_i * mask[:, None, None]
         
     | 
| 177 | 
         
            +
                        residual = residual - z_q_i
         
     | 
| 178 | 
         
            +
             
     | 
| 179 | 
         
            +
                        # Sum losses
         
     | 
| 180 | 
         
            +
                        commitment_loss += (commitment_loss_i * mask).mean()
         
     | 
| 181 | 
         
            +
                        codebook_loss += (codebook_loss_i * mask).mean()
         
     | 
| 182 | 
         
            +
             
     | 
| 183 | 
         
            +
                        codebook_indices.append(indices_i)
         
     | 
| 184 | 
         
            +
                        latents.append(z_e_i)
         
     | 
| 185 | 
         
            +
             
     | 
| 186 | 
         
            +
                    codes = torch.stack(codebook_indices, dim=1)
         
     | 
| 187 | 
         
            +
                    latents = torch.cat(latents, dim=1)
         
     | 
| 188 | 
         
            +
             
     | 
| 189 | 
         
            +
                    return z_q, codes, latents, commitment_loss, codebook_loss
         
     | 
| 190 | 
         
            +
             
     | 
| 191 | 
         
            +
                def from_codes(self, codes: torch.Tensor):
         
     | 
| 192 | 
         
            +
                    """Given the quantized codes, reconstruct the continuous representation
         
     | 
| 193 | 
         
            +
                    Parameters
         
     | 
| 194 | 
         
            +
                    ----------
         
     | 
| 195 | 
         
            +
                    codes : Tensor[B x N x T]
         
     | 
| 196 | 
         
            +
                        Quantized discrete representation of input
         
     | 
| 197 | 
         
            +
                    Returns
         
     | 
| 198 | 
         
            +
                    -------
         
     | 
| 199 | 
         
            +
                    Tensor[B x D x T]
         
     | 
| 200 | 
         
            +
                        Quantized continuous representation of input
         
     | 
| 201 | 
         
            +
                    """
         
     | 
| 202 | 
         
            +
                    z_q = 0.0
         
     | 
| 203 | 
         
            +
                    z_p = []
         
     | 
| 204 | 
         
            +
                    n_codebooks = codes.shape[1]
         
     | 
| 205 | 
         
            +
                    for i in range(n_codebooks):
         
     | 
| 206 | 
         
            +
                        z_p_i = self.quantizers[i].decode_code(codes[:, i, :])
         
     | 
| 207 | 
         
            +
                        z_p.append(z_p_i)
         
     | 
| 208 | 
         
            +
             
     | 
| 209 | 
         
            +
                        z_q_i = self.quantizers[i].out_proj(z_p_i)
         
     | 
| 210 | 
         
            +
                        z_q = z_q + z_q_i
         
     | 
| 211 | 
         
            +
                    return z_q, torch.cat(z_p, dim=1), codes
         
     | 
| 212 | 
         
            +
             
     | 
| 213 | 
         
            +
                def from_latents(self, latents: torch.Tensor):
         
     | 
| 214 | 
         
            +
                    """Given the unquantized latents, reconstruct the
         
     | 
| 215 | 
         
            +
                    continuous representation after quantization.
         
     | 
| 216 | 
         
            +
             
     | 
| 217 | 
         
            +
                    Parameters
         
     | 
| 218 | 
         
            +
                    ----------
         
     | 
| 219 | 
         
            +
                    latents : Tensor[B x N x T]
         
     | 
| 220 | 
         
            +
                        Continuous representation of input after projection
         
     | 
| 221 | 
         
            +
             
     | 
| 222 | 
         
            +
                    Returns
         
     | 
| 223 | 
         
            +
                    -------
         
     | 
| 224 | 
         
            +
                    Tensor[B x D x T]
         
     | 
| 225 | 
         
            +
                        Quantized representation of full-projected space
         
     | 
| 226 | 
         
            +
                    Tensor[B x D x T]
         
     | 
| 227 | 
         
            +
                        Quantized representation of latent space
         
     | 
| 228 | 
         
            +
                    """
         
     | 
| 229 | 
         
            +
                    z_q = 0
         
     | 
| 230 | 
         
            +
                    z_p = []
         
     | 
| 231 | 
         
            +
                    codes = []
         
     | 
| 232 | 
         
            +
                    dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers])
         
     | 
| 233 | 
         
            +
             
     | 
| 234 | 
         
            +
                    n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[0]
         
     | 
| 235 | 
         
            +
                    for i in range(n_codebooks):
         
     | 
| 236 | 
         
            +
                        j, k = dims[i], dims[i + 1]
         
     | 
| 237 | 
         
            +
                        z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :])
         
     | 
| 238 | 
         
            +
                        z_p.append(z_p_i)
         
     | 
| 239 | 
         
            +
                        codes.append(codes_i)
         
     | 
| 240 | 
         
            +
             
     | 
| 241 | 
         
            +
                        z_q_i = self.quantizers[i].out_proj(z_p_i)
         
     | 
| 242 | 
         
            +
                        z_q = z_q + z_q_i
         
     | 
| 243 | 
         
            +
             
     | 
| 244 | 
         
            +
                    return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1)
         
     | 
| 245 | 
         
            +
             
     | 
| 246 | 
         
            +
             
     | 
| 247 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 248 | 
         
            +
                rvq = ResidualVectorQuantize(quantizer_dropout=True)
         
     | 
| 249 | 
         
            +
                x = torch.randn(16, 512, 80)
         
     | 
| 250 | 
         
            +
                y = rvq(x)
         
     | 
| 251 | 
         
            +
                print(y["latents"].shape)
         
     | 
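The `__main__` smoke test above only exercises `forward`. The sketch below (not part of the commit; constructor defaults for input dim and number of codebooks are assumed from the DAC-style signature) round-trips the discrete codes through `from_codes` to recover the same quantized representation:

# Hedged round-trip sketch for ResidualVectorQuantize (illustrative, not in the diff).
import torch

rvq = ResidualVectorQuantize(quantizer_dropout=0.0)  # assumed defaults: 512-dim input
rvq.eval()  # eval mode: the random quantizer-dropout branch is skipped

x = torch.randn(2, 512, 80)  # [B x D x T]
with torch.no_grad():
    z_q, codes, latents, commitment_loss, codebook_loss = rvq(x)
    z_q_rec, z_p, _ = rvq.from_codes(codes)  # rebuild from discrete codes alone

assert torch.allclose(z_q, z_q_rec, atol=1e-5)  # same quantized representation
print(codes.shape)  # [B x N x T] codebook indices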
    	
        higgs_audio/audio_processing/higgs_audio_tokenizer.py
    ADDED
    
@@ -0,0 +1,341 @@
# Based on code from: https://github.com/zhenye234/xcodec
# Licensed under MIT License
# Modifications by BosonAI

import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Union, Sequence
import numpy as np
from transformers import AutoModel
import torchaudio
import json
import librosa
from huggingface_hub import snapshot_download

from vector_quantize_pytorch import ResidualFSQ
from .descriptaudiocodec.dac.model import dac as dac2
from .quantization.vq import ResidualVectorQuantizer
from .semantic_module import Encoder, Decoder


class EncodedResult:
    def __init__(self, audio_codes):
        self.audio_codes = audio_codes


class HiggsAudioFeatureExtractor(nn.Module):
    def __init__(self, sampling_rate=16000):
        super().__init__()
        self.sampling_rate = sampling_rate

    def forward(self, raw_audio, sampling_rate=16000, return_tensors="pt"):
        # Convert from librosa (numpy) to torch and add batch/channel dims.
        # `sampling_rate` and `return_tensors` mirror the HF feature-extractor
        # signature and are accepted for compatibility.
        audio_signal = torch.tensor(raw_audio)
        audio_signal = audio_signal.unsqueeze(0)
        if len(audio_signal.shape) < 3:
            audio_signal = audio_signal.unsqueeze(0)
        return {"input_values": audio_signal}


class HiggsAudioTokenizer(nn.Module):
    def __init__(
        self,
        n_filters: int = 32,
        D: int = 128,
        target_bandwidths: Sequence[Union[int, float]] = [1, 1.5, 2, 4, 6],
        ratios: Sequence[int] = [8, 5, 4, 2],  # downsampling by 320
        sample_rate: int = 16000,
        bins: int = 1024,
        n_q: int = 8,
        codebook_dim: Optional[int] = None,
        normalize: bool = False,
        causal: bool = False,
        semantic_techer: str = "hubert_base_general",
        last_layer_semantic: bool = True,
        merge_mode: str = "concat",
        downsample_mode: str = "step_down",
        semantic_mode: str = "classic",
        vq_scale: int = 1,
        semantic_sample_rate: Optional[int] = None,
        device: str = "cuda",
    ):
        super().__init__()
        self.hop_length = np.prod(ratios)
        self.semantic_techer = semantic_techer

        self.frame_rate = math.ceil(sample_rate / np.prod(ratios))  # 50 Hz

        self.target_bandwidths = target_bandwidths
        self.n_q = n_q
        self.sample_rate = sample_rate
        self.encoder = dac2.Encoder(64, ratios, D)

        self.decoder_2 = dac2.Decoder(D, 1024, ratios)
        self.last_layer_semantic = last_layer_semantic
        self.device = device
        if semantic_techer == "hubert_base":
            self.semantic_model = AutoModel.from_pretrained("facebook/hubert-base-ls960")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        elif semantic_techer == "wavlm_base_plus":
            self.semantic_model = AutoModel.from_pretrained("microsoft/wavlm-base-plus")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        elif semantic_techer == "hubert_base_general":
            self.semantic_model = AutoModel.from_pretrained("ZhenYe234/hubert_base_general_audio")
            self.semantic_sample_rate = 16000
            self.semantic_dim = 768
            self.encoder_semantic_dim = 768

        # Overwrite the semantic model sample rate to ensure
        # semantic_downsample_factor is an integer
        if semantic_sample_rate is not None:
            self.semantic_sample_rate = semantic_sample_rate

        self.semantic_model.eval()

        # Freeze the semantic model: its parameters receive no gradients
        for param in self.semantic_model.parameters():
            param.requires_grad = False

        self.semantic_downsample_factor = int(self.hop_length / (self.sample_rate / self.semantic_sample_rate) / 320)

        self.quantizer_dim = int((D + self.encoder_semantic_dim) // vq_scale)
        self.encoder_semantic = Encoder(input_channels=self.semantic_dim, encode_channels=self.encoder_semantic_dim)
        self.decoder_semantic = Decoder(
            code_dim=self.encoder_semantic_dim,
            output_channels=self.semantic_dim,
            decode_channels=self.semantic_dim,
        )

        # out_D = D + 768
        if isinstance(bins, int):  # RVQ
            self.quantizer = ResidualVectorQuantizer(
                dimension=self.quantizer_dim,
                codebook_dim=codebook_dim,
                n_q=n_q,
                bins=bins,
            )
            self.quantizer_type = "RVQ"
        else:  # RFSQ
            self.quantizer = ResidualFSQ(dim=self.quantizer_dim, levels=bins, num_quantizers=n_q)
            self.quantizer_type = "RFSQ"

        self.fc_prior = nn.Linear(D + self.encoder_semantic_dim, self.quantizer_dim)
        self.fc_post1 = nn.Linear(self.quantizer_dim, self.encoder_semantic_dim)
        self.fc_post2 = nn.Linear(self.quantizer_dim, D)

        self.downsample_mode = downsample_mode
        if downsample_mode == "avg":
            self.semantic_pooling = nn.AvgPool1d(
                kernel_size=self.semantic_downsample_factor,
                stride=self.semantic_downsample_factor,
            )

        self.audio_tokenizer_feature_extractor = HiggsAudioFeatureExtractor(sampling_rate=self.sample_rate)

    @property
    def tps(self):
        return self.frame_rate

    @property
    def sampling_rate(self):
        return self.sample_rate

    @property
    def num_codebooks(self):
        return self.n_q

    @property
    def codebook_size(self):
        return self.quantizer_dim

    def get_last_layer(self):
        # NOTE: appears to be legacy from xcodec; this module defines `decoder_2`,
        # not `decoder`, so this accessor is presumably unused here.
        return self.decoder.layers[-1].weight

    def calculate_rec_loss(self, rec, target):
        # Cosine-distance reconstruction loss between L2-normalized embeddings
        target = target / target.norm(dim=-1, keepdim=True)
        rec = rec / rec.norm(dim=-1, keepdim=True)
        rec_loss = (1 - (target * rec).sum(-1)).mean()

        return rec_loss

    @torch.no_grad()
    def get_regress_target(self, x):
        x = torchaudio.functional.resample(x, self.sample_rate, self.semantic_sample_rate)

        if (
            self.semantic_techer == "hubert_base"
            or self.semantic_techer == "hubert_base_general"
            or self.semantic_techer == "wavlm_base_plus"
        ):
            x = x[:, 0, :]
            x = F.pad(x, (160, 160))
            target = self.semantic_model(x, output_hidden_states=True).hidden_states
            target = torch.stack(target, dim=1)  # .transpose(-1, -2)#.flatten(start_dim=1, end_dim=2)

            # average over all layers
            target = target.mean(1)
            # target = target[9]
            # if self.hop_length > 320:
            #     target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)

        elif self.semantic_techer == "w2v_bert2":
            target = self.semantic_model(x)

        elif self.semantic_techer.startswith("whisper"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("mert_music"):
            if self.last_layer_semantic:
                target = self.semantic_model(x, avg_layers=False)
            else:
                target = self.semantic_model(x, avg_layers=True)

        elif self.semantic_techer.startswith("qwen_audio_omni"):
            target = self.semantic_model(x)

        if self.downsample_mode == "step_down":
            if self.semantic_downsample_factor > 1:
                target = target[:, :: self.semantic_downsample_factor, :]

        elif self.downsample_mode == "avg":
            target = self.semantic_pooling(target.transpose(1, 2)).transpose(1, 2)
        return target

    def forward(self, x: torch.Tensor, bw: int):
        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            quantized = quantized.transpose(1, 2)
        else:
            quantized, codes = self.quantizer(e)
            commit_loss = torch.tensor(0.0)

        quantized_semantic = self.fc_post1(quantized).transpose(1, 2)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)

        o_semantic = self.decoder_semantic(quantized_semantic)
        semantic_recon_loss = F.mse_loss(e_semantic_input.transpose(1, 2).detach(), o_semantic)

        return o, commit_loss, semantic_recon_loss, None

    def encode(
        self,
        audio_path_or_wv,
        sr=None,
        loudness_normalize=False,
        loudness_threshold=-23.0,
    ):
        if isinstance(audio_path_or_wv, str):
            wv, sr = librosa.load(audio_path_or_wv, mono=True, sr=None)
        else:
            wv = audio_path_or_wv
            assert sr is not None
        if loudness_normalize:
            import pyloudnorm as pyln

            meter = pyln.Meter(sr)
            loudness = meter.integrated_loudness(wv)
            wv = pyln.normalize.loudness(wv, loudness, loudness_threshold)
        if sr != self.sampling_rate:
            wv = librosa.resample(wv, orig_sr=sr, target_sr=self.sampling_rate)
        if self.audio_tokenizer_feature_extractor is not None:
            inputs = self.audio_tokenizer_feature_extractor(
                raw_audio=wv,
                sampling_rate=self.audio_tokenizer_feature_extractor.sampling_rate,
                return_tensors="pt",
            )
            input_values = inputs["input_values"].to(self.device)
        else:
            input_values = torch.from_numpy(wv).float().unsqueeze(0)
        with torch.no_grad():
            encoder_outputs = self._xcodec_encode(input_values)
            vq_code = encoder_outputs.audio_codes[0]
        return vq_code

    def _xcodec_encode(self, x: torch.Tensor, target_bw: Optional[int] = None) -> torch.Tensor:
        bw = target_bw

        e_semantic_input = self.get_regress_target(x).detach()

        e_semantic = self.encoder_semantic(e_semantic_input.transpose(1, 2))
        e_acoustic = self.encoder(x)

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            pad_size = 160 * self.semantic_downsample_factor
            e_acoustic = self.encoder(F.pad(x[:, 0, :], (pad_size, pad_size)).unsqueeze(0))

        if e_acoustic.shape[2] != e_semantic.shape[2]:
            if e_acoustic.shape[2] > e_semantic.shape[2]:
                e_acoustic = e_acoustic[:, :, : e_semantic.shape[2]]
            else:
                e_semantic = e_semantic[:, :, : e_acoustic.shape[2]]

        e = torch.cat([e_acoustic, e_semantic], dim=1)

        e = self.fc_prior(e.transpose(1, 2))

        if self.quantizer_type == "RVQ":
            e = e.transpose(1, 2)
            quantized, codes, bandwidth, commit_loss = self.quantizer(e, self.frame_rate, bw)
            codes = codes.permute(1, 0, 2)
        else:
            quantized, codes = self.quantizer(e)
            codes = codes.permute(0, 2, 1)

        # return codes
        return EncodedResult(codes)

    def decode(self, vq_code: torch.Tensor) -> torch.Tensor:
        if self.quantizer_type == "RVQ":
            vq_code = vq_code.permute(1, 0, 2)
            quantized = self.quantizer.decode(vq_code)
            quantized = quantized.transpose(1, 2)
        else:
            vq_code = vq_code.permute(0, 2, 1)
            quantized = self.quantizer.get_output_from_indices(vq_code)
        quantized_acoustic = self.fc_post2(quantized).transpose(1, 2)

        o = self.decoder_2(quantized_acoustic)
        return o.cpu().numpy()


def load_higgs_audio_tokenizer(tokenizer_name_or_path, device="cuda"):
    is_local = os.path.exists(tokenizer_name_or_path)
    if not is_local:
        tokenizer_path = snapshot_download(tokenizer_name_or_path)
    else:
        tokenizer_path = tokenizer_name_or_path
    config_path = os.path.join(tokenizer_path, "config.json")
    model_path = os.path.join(tokenizer_path, "model.pth")
    config = json.load(open(config_path))
    model = HiggsAudioTokenizer(
        **config,
        device=device,
    )
    parameter_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(parameter_dict, strict=False)
    model.to(device)
    model.eval()
    return model
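For reference, a minimal end-to-end sketch of the tokenizer API defined above (not part of the commit; the checkpoint path is a placeholder and a CUDA device is assumed, matching the defaults):

# Hedged usage sketch for HiggsAudioTokenizer (illustrative, not in the diff).
import torch

# Placeholder: substitute the actual tokenizer repo id or a local checkpoint directory
tokenizer = load_higgs_audio_tokenizer("<tokenizer-name-or-path>", device="cuda")

# encode() accepts a file path (sample rate inferred) or a waveform plus its sample rate
vq_code = tokenizer.encode("<path/to/audio.wav>")  # [n_q x T] discrete codes
print(vq_code.shape, tokenizer.tps, tokenizer.sampling_rate)

# decode() expects a batch dimension: [B x n_q x T] -> waveform as a numpy array
wav = tokenizer.decode(vq_code.unsqueeze(0))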
    	
        higgs_audio/audio_processing/quantization/__init__.py
    ADDED
    
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from .vq import QuantizedResult, ResidualVectorQuantizer
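The re-export keeps call sites short. A sketch of the intended import path (arguments mirror the RVQ branch of `HiggsAudioTokenizer.__init__`):

# Hedged import sketch (illustrative, not in the diff).
from higgs_audio.audio_processing.quantization import ResidualVectorQuantizer

# dimension = D + encoder_semantic_dim = 128 + 768 with the tokenizer defaults
rvq = ResidualVectorQuantizer(dimension=896, codebook_dim=None, n_q=8, bins=1024)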
    	
        higgs_audio/audio_processing/quantization/ac.py
    ADDED
    
@@ -0,0 +1,301 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Arithmetic coder."""

import io
import math
import random
import typing as tp

import torch

from ..binary import BitPacker, BitUnpacker


def build_stable_quantized_cdf(
    pdf: torch.Tensor,
    total_range_bits: int,
    roundoff: float = 1e-8,
    min_range: int = 2,
    check: bool = True,
) -> torch.Tensor:
    """Turn the given PDF into a quantized CDF that splits
    [0, 2 ** total_range_bits - 1] into chunks of size roughly proportional
    to the PDF.

    Args:
        pdf (torch.Tensor): probability distribution, shape should be `[N]`.
        total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
            during the coding process is `[0, 2 ** total_range_bits - 1]`.
        roundoff (float): will round the pdf up to that level to remove differences coming
            from e.g. evaluating the Language Model on different architectures.
        min_range (int): minimum range width. Should always be at least 2 for numerical
            stability. Use this to avoid pathological behavior when a value
            that is expected to be rare actually happens in real life.
        check (bool): if True, checks that nothing bad happened; can be deactivated for speed.
    """
    pdf = pdf.detach()
    if roundoff:
        pdf = (pdf / roundoff).floor() * roundoff
    # Interpolate with the uniform distribution to achieve the desired minimum probability.
    total_range = 2**total_range_bits
    cardinality = len(pdf)
    alpha = min_range * cardinality / total_range
    assert alpha <= 1, "you must reduce min_range"
    ranges = (((1 - alpha) * total_range) * pdf).floor().long()
    ranges += min_range
    quantized_cdf = torch.cumsum(ranges, dim=-1)
    if min_range < 2:
        raise ValueError("min_range must be at least 2.")
    if check:
        assert quantized_cdf[-1] <= 2**total_range_bits, quantized_cdf[-1]
        if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
            raise ValueError("You must increase your total_range_bits.")
    return quantized_cdf

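# --- Illustrative sketch (not part of the commit) ---------------------------
# How `build_stable_quantized_cdf` might be exercised on a toy softmax
# distribution, checking the invariants the docstring promises: a strictly
# increasing CDF (every symbol gets a non-empty chunk) that fits inside the
# coding range. The `_example_*` names are our own.
_example_pdf = torch.softmax(torch.randn(16), dim=0)
_example_cdf = build_stable_quantized_cdf(_example_pdf, total_range_bits=24)
assert (_example_cdf[1:] > _example_cdf[:-1]).all()
assert _example_cdf[-1] <= 2**24
# -----------------------------------------------------------------------------
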
class ArithmeticCoder:
    """ArithmeticCoder.

    Let us take a distribution `p` over `N` symbols, and assume we have a stream
    of random variables `s_t` sampled from `p`. Let us assume that we have a budget
    of `B` bits that we can afford to write on device. There are `2**B` possible numbers,
    corresponding to the range `[0, 2 ** B - 1]`. We can map each of those numbers to a single
    sequence `(s_t)` by doing the following:

    1) Initialize the current range to `[0, 2 ** B - 1]`.
    2) For each time step t, split the current range into contiguous chunks,
        one for each possible outcome, with size roughly proportional to `p`.
        For instance, if `p = [0.75, 0.25]`, and the range is `[0, 3]`, the chunks
        would be `{[0, 2], [3, 3]}`.
    3) Select the chunk corresponding to `s_t`, and replace the current range with this.
    4) When done encoding all the values, just select any value remaining in the range.

    You will notice that this procedure can fail: for instance if at any point in time
    the range is smaller than `N`, then we can no longer assign a non-empty chunk to each
    possible outcome. Intuitively, the more likely a value is, the less the range width
    will reduce, and the longer we can go on encoding values. This makes sense: for any efficient
    coding scheme, likely outcomes would take fewer bits, and more of them can be coded
    with a fixed budget.

    In practice, we do not know `B` ahead of time, but we have a way to inject new bits
    when the current range decreases below a given limit (given by `total_range_bits`), without
    having to redo all the computations. If we mostly encode likely values, we will seldom
    need to inject new bits, but a single rare value can deplete our stock of entropy!

    In this explanation, we assumed that the distribution `p` was constant. In fact, the present
    code works for any sequence `(p_t)`, possibly different for each timestep.
    We also assume that `s_t ~ p_t`, but that doesn't need to be true, although the smaller
    the KL between the true distribution and `p_t`, the more efficient the coding will be.

    Args:
        fo (IO[bytes]): file-like object to which the bytes will be written.
        total_range_bits (int): the range `M` described above is `2 ** total_range_bits`.
            Any time the current range width falls under this limit, new bits will
            be injected to rescale the initial range.
    """

    def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
        assert total_range_bits <= 30
        self.total_range_bits = total_range_bits
        self.packer = BitPacker(bits=1, fo=fo)  # we push single bits at a time.
        self.low: int = 0
        self.high: int = 0
        self.max_bit: int = -1
        # Following is for debugging.
        self._dbg: tp.List[tp.Any] = []
        self._dbg2: tp.List[tp.Any] = []

    @property
    def delta(self) -> int:
        """Return the current range width."""
        return self.high - self.low + 1

    def _flush_common_prefix(self):
        # If self.low and self.high start with the same bits,
        # those won't change anymore as we always just increase the range
        # by powers of 2, and we can flush them out to the bit stream.
        assert self.high >= self.low, (self.low, self.high)
        assert self.high < 2 ** (self.max_bit + 1)
        while self.max_bit >= 0:
            b1 = self.low >> self.max_bit
            b2 = self.high >> self.max_bit
            if b1 == b2:
                self.low -= b1 << self.max_bit
                self.high -= b1 << self.max_bit
                assert self.high >= self.low, (self.high, self.low, self.max_bit)
                assert self.low >= 0
                self.max_bit -= 1
                self.packer.push(b1)
            else:
                break

    def push(self, symbol: int, quantized_cdf: torch.Tensor):
        """Push the given symbol on the stream, flushing out bits
        if possible.

        Args:
            symbol (int): symbol to encode with the AC.
            quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
                to build this from your pdf estimate.
        """
        while self.delta < 2**self.total_range_bits:
            self.low *= 2
            self.high = self.high * 2 + 1
            self.max_bit += 1

        range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
        range_high = quantized_cdf[symbol].item() - 1
        effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
        effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
        assert self.low <= self.high
        self.high = self.low + effective_high
        self.low = self.low + effective_low
        assert self.low <= self.high, (
            effective_low,
            effective_high,
            range_low,
            range_high,
        )
        self._dbg.append((self.low, self.high))
        self._dbg2.append((self.low, self.high))
        outs = self._flush_common_prefix()
        assert self.low <= self.high
        assert self.max_bit >= -1
        assert self.max_bit <= 61, self.max_bit
        return outs

    def flush(self):
        """Flush the remaining information to the stream."""
        while self.max_bit >= 0:
            b1 = (self.low >> self.max_bit) & 1
            self.packer.push(b1)
            self.max_bit -= 1
        self.packer.flush()

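# --- Worked sketch (not part of the commit) ----------------------------------
# The docstring's `p = [0.75, 0.25]` example, mirroring the chunking arithmetic
# in `ArithmeticCoder.push` with a coding range of [0, 3] (total_range_bits = 2)
# and the quantized CDF [3, 4]. The `_`-prefixed names are our own.
_delta = 4                      # current range width, high - low + 1
_q_cdf = [3, 4]                 # symbol 0 owns slots [0, 2], symbol 1 owns [3, 3]
_chunks = []
for _symbol in (0, 1):
    _range_low = 0 if _symbol == 0 else _q_cdf[_symbol - 1]
    _range_high = _q_cdf[_symbol] - 1
    _chunks.append(
        (
            int(math.ceil(_range_low * (_delta / 2**2))),
            int(math.floor(_range_high * (_delta / 2**2))),
        )
    )
assert _chunks == [(0, 2), (3, 3)]  # matches the docstring's {[0, 2], [3, 3]}
# ------------------------------------------------------------------------------
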
         
            +
            class ArithmeticDecoder:
         
     | 
| 179 | 
         
            +
                """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
                Note that this must be called with **exactly** the same parameters and sequence
         
     | 
| 182 | 
         
            +
                of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
         
     | 
| 183 | 
         
            +
             
     | 
| 184 | 
         
            +
                If the AC encoder current range is [L, H], with `L` and `H` having the some common
         
     | 
| 185 | 
         
            +
                prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
         
     | 
| 186 | 
         
            +
                For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
         
     | 
| 187 | 
         
            +
                `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
         
     | 
| 188 | 
         
            +
                for a specific sequence of symbols and a binary-search allows us to decode those symbols.
         
     | 
| 189 | 
         
            +
                At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
         
     | 
| 190 | 
         
            +
                and we will need to read new bits from the stream and repeat the process.
         
     | 
| 191 | 
         
            +
             
     | 
| 192 | 
         
            +
                """
         
     | 
| 193 | 
         
            +
             
     | 
| 194 | 
         
            +
                def __init__(self, fo: tp.IO[bytes], total_range_bits: int = 24):
         
     | 
| 195 | 
         
            +
                    self.total_range_bits = total_range_bits
         
     | 
| 196 | 
         
            +
                    self.low: int = 0
         
     | 
| 197 | 
         
            +
                    self.high: int = 0
         
     | 
| 198 | 
         
            +
                    self.current: int = 0
         
     | 
| 199 | 
         
            +
                    self.max_bit: int = -1
         
     | 
| 200 | 
         
            +
                    self.unpacker = BitUnpacker(bits=1, fo=fo)  # we pull single bits at a time.
         
     | 
| 201 | 
         
            +
                    # Following is for debugging
         
     | 
| 202 | 
         
            +
                    self._dbg: tp.List[tp.Any] = []
         
     | 
| 203 | 
         
            +
                    self._dbg2: tp.List[tp.Any] = []
         
     | 
| 204 | 
         
            +
                    self._last: tp.Any = None
         
     | 
| 205 | 
         
            +
             
     | 
| 206 | 
         
            +
                @property
         
     | 
| 207 | 
         
            +
                def delta(self) -> int:
         
     | 
| 208 | 
         
            +
                    return self.high - self.low + 1
         
     | 
| 209 | 
         
            +
             
     | 
| 210 | 
         
            +
                def _flush_common_prefix(self):
         
     | 
| 211 | 
         
            +
                    # Given the current range [L, H], if both have a common prefix,
         
     | 
| 212 | 
         
            +
                    # we know we can remove it from our representation to avoid handling large numbers.
         
     | 
| 213 | 
         
            +
                    while self.max_bit >= 0:
         
     | 
| 214 | 
         
            +
                        b1 = self.low >> self.max_bit
         
     | 
| 215 | 
         
            +
                        b2 = self.high >> self.max_bit
         
     | 
| 216 | 
         
            +
                        if b1 == b2:
         
     | 
| 217 | 
         
            +
                            self.low -= b1 << self.max_bit
         
     | 
| 218 | 
         
            +
                            self.high -= b1 << self.max_bit
         
     | 
| 219 | 
         
            +
                            self.current -= b1 << self.max_bit
         
     | 
| 220 | 
         
            +
                            assert self.high >= self.low
         
     | 
| 221 | 
         
            +
                            assert self.low >= 0
         
     | 
| 222 | 
         
            +
                            self.max_bit -= 1
         
     | 
| 223 | 
         
            +
                        else:
         
     | 
| 224 | 
         
            +
                            break
         
     | 
| 225 | 
         
            +
             
     | 
| 226 | 
         
            +
                def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
         
     | 
| 227 | 
         
            +
                    """Pull a symbol, reading as many bits from the stream as required.
         
     | 
| 228 | 
         
            +
                    This returns `None` when the stream has been exhausted.
         
     | 
| 229 | 
         
            +
             
     | 
| 230 | 
         
            +
                    Args:
         
     | 
| 231 | 
         
            +
                        quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
         
     | 
| 232 | 
         
            +
                            to build this from your pdf estimate. This must be **exatly**
         
     | 
| 233 | 
         
            +
                            the same cdf as the one used at encoding time.
         
     | 
| 234 | 
         
            +
                    """
         
     | 
| 235 | 
         
            +
                    while self.delta < 2**self.total_range_bits:
         
     | 
| 236 | 
         
            +
                        bit = self.unpacker.pull()
         
     | 
| 237 | 
         
            +
                        if bit is None:
         
     | 
| 238 | 
         
            +
                            return None
         
     | 
| 239 | 
         
            +
                        self.low *= 2
         
     | 
| 240 | 
         
            +
                        self.high = self.high * 2 + 1
         
     | 
| 241 | 
         
            +
                        self.current = self.current * 2 + bit
         
     | 
| 242 | 
         
            +
                        self.max_bit += 1
         
     | 
| 243 | 
         
            +
             
     | 
| 244 | 
         
            +
                    def bin_search(low_idx: int, high_idx: int):
         
     | 
| 245 | 
         
            +
                        # Binary search is not just for coding interviews :)
         
     | 
| 246 | 
         
            +
                        if high_idx < low_idx:
         
     | 
| 247 | 
         
            +
                            raise RuntimeError("Binary search failed")
         
     | 
| 248 | 
         
            +
                        mid = (low_idx + high_idx) // 2
         
     | 
| 249 | 
         
            +
                        range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
         
     | 
| 250 | 
         
            +
                        range_high = quantized_cdf[mid].item() - 1
         
     | 
| 251 | 
         
            +
                        effective_low = int(math.ceil(range_low * (self.delta / (2**self.total_range_bits))))
         
     | 
| 252 | 
         
            +
                        effective_high = int(math.floor(range_high * (self.delta / (2**self.total_range_bits))))
         
     | 
| 253 | 
         
            +
                        low = effective_low + self.low
         
     | 
| 254 | 
         
            +
                        high = effective_high + self.low
         
     | 
| 255 | 
         
            +
                        if self.current >= low:
         
     | 
| 256 | 
         
            +
                            if self.current <= high:
         
     | 
| 257 | 
         
            +
                                return (mid, low, high, self.current)
         
     | 
| 258 | 
         
            +
                            else:
         
     | 
| 259 | 
         
            +
                                return bin_search(mid + 1, high_idx)
         
     | 
| 260 | 
         
            +
                        else:
         
     | 
| 261 | 
         
            +
                            return bin_search(low_idx, mid - 1)
         
     | 
| 262 | 
         
            +
             
     | 
| 263 | 
         
            +
                    self._last = (self.low, self.high, self.current, self.max_bit)
         
     | 
| 264 | 
         
            +
                    sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
         
     | 
| 265 | 
         
            +
                    self._dbg.append((self.low, self.high, self.current))
         
     | 
| 266 | 
         
            +
                    self._flush_common_prefix()
         
     | 
| 267 | 
         
            +
                    self._dbg2.append((self.low, self.high, self.current))
         
     | 
| 268 | 
         
            +
             
     | 
| 269 | 
         
            +
                    return sym
         
     | 
| 270 | 
         
            +
             
     | 
| 271 | 
         
            +
             
     | 
| 272 | 
         
            +
            def test():
         
     | 
| 273 | 
         
            +
                torch.manual_seed(1234)
         
     | 
| 274 | 
         
            +
                random.seed(1234)
         
     | 
| 275 | 
         
            +
                for _ in range(4):
         
     | 
| 276 | 
         
            +
                    pdfs = []
         
     | 
| 277 | 
         
            +
                    cardinality = random.randrange(4000)
         
     | 
| 278 | 
         
            +
                    steps = random.randrange(100, 500)
         
     | 
| 279 | 
         
            +
                    fo = io.BytesIO()
         
     | 
| 280 | 
         
            +
                    encoder = ArithmeticCoder(fo)
         
     | 
| 281 | 
         
            +
                    symbols = []
         
     | 
| 282 | 
         
            +
                    for step in range(steps):
         
     | 
| 283 | 
         
            +
                        pdf = torch.softmax(torch.randn(cardinality), dim=0)
         
     | 
| 284 | 
         
            +
                        pdfs.append(pdf)
         
     | 
| 285 | 
         
            +
                        q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
         
     | 
| 286 | 
         
            +
                        symbol = torch.multinomial(pdf, 1).item()
         
     | 
| 287 | 
         
            +
                        symbols.append(symbol)
         
     | 
| 288 | 
         
            +
                        encoder.push(symbol, q_cdf)
         
     | 
| 289 | 
         
            +
                    encoder.flush()
         
     | 
| 290 | 
         
            +
             
     | 
| 291 | 
         
            +
                    fo.seek(0)
         
     | 
| 292 | 
         
            +
                    decoder = ArithmeticDecoder(fo)
         
     | 
| 293 | 
         
            +
                    for idx, (pdf, symbol) in enumerate(zip(pdfs, symbols)):
         
     | 
| 294 | 
         
            +
                        q_cdf = build_stable_quantized_cdf(pdf, encoder.total_range_bits)
         
     | 
| 295 | 
         
            +
                        decoded_symbol = decoder.pull(q_cdf)
         
     | 
| 296 | 
         
            +
                        assert decoded_symbol == symbol, idx
         
     | 
| 297 | 
         
            +
                    assert decoder.pull(torch.zeros(1)) is None
         
     | 
| 298 | 
         
            +
             
     | 
| 299 | 
         
            +
             
     | 
| 300 | 
         
            +
            if __name__ == "__main__":
         
     | 
| 301 | 
         
            +
                test()
         
     | 
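For orientation, here is a minimal round trip through this module: encode three symbols under a fixed distribution, then decode them back with the same sequence of quantized CDFs, exactly as `test()` above does with random data. This is an illustrative sketch, assuming the module is importable at the path it is committed to (its `..binary` dependency must also be present); the distribution, symbols, and variable names are our own.

import io
import torch

from higgs_audio.audio_processing.quantization.ac import (
    ArithmeticCoder,
    ArithmeticDecoder,
    build_stable_quantized_cdf,
)

pdf = torch.tensor([0.7, 0.2, 0.1])            # toy distribution over 3 symbols
fo = io.BytesIO()
coder = ArithmeticCoder(fo)
q_cdf = build_stable_quantized_cdf(pdf, coder.total_range_bits)
for s in (0, 2, 1):                            # encode three symbols
    coder.push(s, q_cdf)
coder.flush()

fo.seek(0)                                     # decode them back with the same CDFs
decoder = ArithmeticDecoder(fo)
assert [decoder.pull(q_cdf) for _ in range(3)] == [0, 2, 1]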
    	
higgs_audio/audio_processing/quantization/core_vq.py
ADDED
@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""

import typing as tp

from einops import rearrange, repeat
import torch
from torch import nn
import torch.nn.functional as F

# Relative import of the sibling `distrib` module added in this commit
# (the file as committed imported it from the external `xcodec` package,
# which is not part of this repository).
from .distrib import broadcast_tensors, rank


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(samples, num_clusters: int, num_iters: int = 10):
    dim, dtype = samples.shape[-1], samples.dtype

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
        dists = -(diffs**2).sum(dim=-1)

        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins

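# --- Illustrative sketch (not part of the commit) ----------------------------
# Cluster a batch of random vectors with the helper above; `means` holds one
# centroid per cluster and `bins` the final per-cluster assignment counts.
# The `_`-prefixed names are our own.
_samples = torch.randn(1024, 8)
_means, _bins = kmeans(_samples, num_clusters=4, num_iters=10)
assert _means.shape == (4, 8)
assert _bins.sum().item() == 1024
# ------------------------------------------------------------------------------
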
| 96 | 
         
            +
            class EuclideanCodebook(nn.Module):
         
     | 
| 97 | 
         
            +
                """Codebook with Euclidean distance.
         
     | 
| 98 | 
         
            +
                Args:
         
     | 
| 99 | 
         
            +
                    dim (int): Dimension.
         
     | 
| 100 | 
         
            +
                    codebook_size (int): Codebook size.
         
     | 
| 101 | 
         
            +
                    kmeans_init (bool): Whether to use k-means to initialize the codebooks.
         
     | 
| 102 | 
         
            +
                        If set to true, run the k-means algorithm on the first training batch and use
         
     | 
| 103 | 
         
            +
                        the learned centroids as initialization.
         
     | 
| 104 | 
         
            +
                    kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
         
     | 
| 105 | 
         
            +
                    decay (float): Decay for exponential moving average over the codebooks.
         
     | 
| 106 | 
         
            +
                    epsilon (float): Epsilon value for numerical stability.
         
     | 
| 107 | 
         
            +
                    threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
         
     | 
| 108 | 
         
            +
                        that have an exponential moving average cluster size less than the specified threshold with
         
     | 
| 109 | 
         
            +
                        randomly selected vector from the current batch.
         
     | 
| 110 | 
         
            +
                """
         
     | 
| 111 | 
         
            +
             
     | 
| 112 | 
         
            +
                def __init__(
         
     | 
| 113 | 
         
            +
                    self,
         
     | 
| 114 | 
         
            +
                    dim: int,
         
     | 
| 115 | 
         
            +
                    codebook_size: int,
         
     | 
| 116 | 
         
            +
                    kmeans_init: int = False,
         
     | 
| 117 | 
         
            +
                    kmeans_iters: int = 10,
         
     | 
| 118 | 
         
            +
                    decay: float = 0.99,
         
     | 
| 119 | 
         
            +
                    epsilon: float = 1e-5,
         
     | 
| 120 | 
         
            +
                    threshold_ema_dead_code: int = 2,
         
     | 
| 121 | 
         
            +
                ):
         
     | 
| 122 | 
         
            +
                    super().__init__()
         
     | 
| 123 | 
         
            +
                    self.decay = decay
         
     | 
| 124 | 
         
            +
                    init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
         
     | 
| 125 | 
         
            +
                    embed = init_fn(codebook_size, dim)
         
     | 
| 126 | 
         
            +
             
     | 
| 127 | 
         
            +
                    self.codebook_size = codebook_size
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
            +
                    self.kmeans_iters = kmeans_iters
         
     | 
| 130 | 
         
            +
                    self.epsilon = epsilon
         
     | 
| 131 | 
         
            +
                    self.threshold_ema_dead_code = threshold_ema_dead_code
         
     | 
| 132 | 
         
            +
             
     | 
| 133 | 
         
            +
                    self.register_buffer("inited", torch.Tensor([not kmeans_init]))
         
     | 
| 134 | 
         
            +
                    self.register_buffer("cluster_size", torch.zeros(codebook_size))
         
     | 
| 135 | 
         
            +
                    self.register_buffer("embed", embed)
         
     | 
| 136 | 
         
            +
                    self.register_buffer("embed_avg", embed.clone())
         
     | 
| 137 | 
         
            +
             
     | 
| 138 | 
         
            +
                @torch.jit.ignore
         
     | 
| 139 | 
         
            +
                def init_embed_(self, data):
         
     | 
| 140 | 
         
            +
                    if self.inited:
         
     | 
| 141 | 
         
            +
                        return
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
            +
                    embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
         
     | 
| 144 | 
         
            +
                    self.embed.data.copy_(embed)
         
     | 
| 145 | 
         
            +
                    self.embed_avg.data.copy_(embed.clone())
         
     | 
| 146 | 
         
            +
                    self.cluster_size.data.copy_(cluster_size)
         
     | 
| 147 | 
         
            +
                    self.inited.data.copy_(torch.Tensor([True]))
         
     | 
| 148 | 
         
            +
                    # Make sure all buffers across workers are in sync after initialization
         
     | 
| 149 | 
         
            +
                    broadcast_tensors(self.buffers())
         
     | 
| 150 | 
         
            +
             
     | 
| 151 | 
         
            +
                def replace_(self, samples, mask):
         
     | 
| 152 | 
         
            +
                    modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
         
     | 
| 153 | 
         
            +
                    self.embed.data.copy_(modified_codebook)
         
     | 
| 154 | 
         
            +
             
     | 
| 155 | 
         
            +
                def expire_codes_(self, batch_samples):
         
     | 
| 156 | 
         
            +
                    if self.threshold_ema_dead_code == 0:
         
     | 
| 157 | 
         
            +
                        return
         
     | 
| 158 | 
         
            +
             
     | 
| 159 | 
         
            +
                    expired_codes = self.cluster_size < self.threshold_ema_dead_code
         
     | 
| 160 | 
         
            +
                    if not torch.any(expired_codes):
         
     | 
| 161 | 
         
            +
                        return
         
     | 
| 162 | 
         
            +
             
     | 
| 163 | 
         
            +
                    batch_samples = rearrange(batch_samples, "... d -> (...) d")
         
     | 
| 164 | 
         
            +
                    self.replace_(batch_samples, mask=expired_codes)
         
     | 
| 165 | 
         
            +
        broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)  # get embedding based on index
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)  # get index based on Euclidean distance
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)

        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
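The quantize method above never materializes pairwise differences: it scores every code with the expanded squared-distance identity, so taking the argmax of the negated expression is exactly nearest-neighbor search over the codebook:

$$\|x - e_k\|^2 = \|x\|^2 - 2\,x^\top e_k + \|e_k\|^2, \qquad \hat{k} = \arg\max_k \, -\bigl(\|x\|^2 - 2\,x^\top e_k + \|e_k\|^2\bigr) = \arg\min_k \|x - e_k\|^2.$$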
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        device = x.device
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize, embed_ind, loss
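The line `quantize = x + (quantize - x).detach()` in forward is the straight-through estimator: the forward value equals the quantized vector, while the backward pass treats quantization as the identity so gradients still reach the encoder. A self-contained illustration (the tensors here are stand-ins, not part of the commit):

# --- illustrative sketch, not part of the commit ---
import torch

x = torch.randn(4, 8, requires_grad=True)
q = torch.randn(4, 8)              # stand-in for the codebook lookup output
st = x + (q - x).detach()          # forward value equals q ...
assert torch.allclose(st, q)
st.sum().backward()                # ... but gradients flow straight through to x
print(x.grad.unique())             # tensor([1.])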
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.
    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])

    def forward(self, x, n_q: tp.Optional[int] = None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []

        n_q = n_q or len(self.layers)

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_indices.append(indices)
            all_losses.append(loss)

        out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for i, indices in enumerate(q_indices):
            layer = self.layers[i]
            quantized = layer.decode(indices)
            quantized_out = quantized_out + quantized
        return quantized_out
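To make the residual control flow concrete, here is a minimal usage sketch of the class above (all shapes and hyperparameters are illustrative, not taken from the commit; `kmeans_init=False` is chosen so the uniformly initialized codebooks are usable without a training step):

# --- illustrative sketch, not part of the commit ---
import torch

rvq = ResidualVectorQuantization(num_quantizers=8, dim=128, codebook_size=1024, kmeans_init=False)
rvq.eval()                              # no EMA updates or code expiry in this sketch

x = torch.randn(2, 128, 50)             # [batch, dim, frames]
codes = rvq.encode(x, n_q=4)            # [4, batch, frames]: coarse-to-fine indices
recon = rvq.decode(codes)               # [batch, dim, frames] approximation of x

Each stage quantizes the residual left by the stages before it, so truncating to a smaller n_q degrades fidelity gracefully while the code stream stays decodable.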
    	
higgs_audio/audio_processing/quantization/core_vq_lsx_version.py
ADDED
@@ -0,0 +1,431 @@
# Copyright (c)
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# This implementation is inspired from
# https://github.com/rosinality/vq-vae-2-pytorch/blob/master/vqvae.py and
# https://github.com/clementchadebec/benchmark_VAE/blob/dfa0dcf6c79172df5d27769c09c860c42008baaa/src/pythae/models/vq_vae/vq_vae_utils.py#L81
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# This implementation is inspired from
# https://github.com/lucidrains/vector-quantize-pytorch
# which is released under MIT License. Hereafter, the original license:
# MIT License
#
# Copyright (c) 2020 Phil Wang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Core vector quantization implementation."""

import typing as tp

from einops import rearrange
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from .distrib import broadcast_tensors, is_distributed
from .ddp_utils import SyncFunction


def default(val: tp.Any, d: tp.Any) -> tp.Any:
    return val if val is not None else d


def ema_inplace(moving_avg, new, decay: float):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    return (x + epsilon) / (x.sum() + n_categories * epsilon)


def uniform_init(*shape: int):
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t
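A quick numeric check of laplace_smoothing (values illustrative, not part of the commit): the output sums to one and every bin, including empty ones, keeps strictly positive mass, which is what later protects the EMA normalization from division by zero:

# --- illustrative sketch, not part of the commit ---
import torch

x = torch.tensor([3.0, 1.0, 0.0])
p = laplace_smoothing(x, n_categories=3, epsilon=1e-5)
print(float(p.sum()))    # ~1.0: a proper distribution
print(bool(p[2] > 0))    # True: the empty bin is lifted off zero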
def sample_vectors(samples, num: int):
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def kmeans(
    samples,
    num_clusters: int,
    num_iters: int = 10,
    frames_to_use: int = 10_000,
    batch_size: int = 64,
):
    """
    Memory-efficient K-means clustering.
    Args:
        samples (tensor): shape [N, D]
        num_clusters (int): number of centroids.
        num_iters (int): number of iterations.
        frames_to_use (int): subsample size from total samples.
        batch_size (int): batch size used in distance computation.
    Returns:
        means: [num_clusters, D]
        bins: [num_clusters] (number of points per cluster)
    """
    N, D = samples.shape
    dtype, device = samples.dtype, samples.device

    if frames_to_use < N:
        indices = torch.randperm(N, device=device)[:frames_to_use]
        samples = samples[indices]

    means = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Store cluster assignments
        all_assignments = []

        for i in range(0, samples.shape[0], batch_size):
            batch = samples[i : i + batch_size]  # [B, D]
            dists = torch.cdist(batch, means, p=2)  # [B, C]
            assignments = dists.argmin(dim=1)  # [B]
            all_assignments.append(assignments)

        buckets = torch.cat(all_assignments, dim=0)  # [N]
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        # Compute new means
        new_means = torch.zeros_like(means)
        for i in range(num_clusters):
            mask = buckets == i
            if mask.any():
                new_means[i] = samples[mask].mean(dim=0)

        means = torch.where(zero_mask[:, None], means, new_means)

    return means, bins
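A minimal invocation of the routine above (shapes illustrative, not part of the commit). Computing distances in batch_size chunks with torch.cdist keeps peak memory near O(batch_size x num_clusters) rather than O(N x num_clusters):

# --- illustrative sketch, not part of the commit ---
import torch

frames = torch.randn(5000, 64)          # [N, D] feature frames
means, bins = kmeans(frames, num_clusters=16, num_iters=10, batch_size=256)
print(means.shape)                      # torch.Size([16, 64])
print(int(bins.sum()))                  # 5000: every frame lands in exactly one bin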
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.
    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # Flag variable to indicate whether the codebook is initialized
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        # Running EMA cluster size/count: N_i^t in eq. (6) in the VQ-VAE paper
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        # Codebook
        self.register_buffer("embed", embed)
        # EMA codebook: eq. (7) in the VQ-VAE paper
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize codebook.
        Args:
            data (tensor): [B * T, D].
        """
        if self.inited:
            return

        ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
        if dist.is_available() and dist.is_initialized():
            # [B * T * world_size, D]
            data = SyncFunction.apply(data)

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        ## NOTE (snippet added by Songxiang Liu): gather data from all gpus
        if is_distributed():
            # [B * T * world_size, D]
            batch_samples = SyncFunction.apply(batch_samples)

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        broadcast_tensors(self.buffers())

    def preprocess(self, x):
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        embed = self.embed.t()
        dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        shape = x.shape
        # pre-process
        x = self.preprocess(x)  # [B, T, D] -> [B*T, D]
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        # shape: [B, T, D]
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)  # [B, T, D] -> [B*T, D]

        # Initialize codebook
        self.init_embed_(x)

        embed_ind = self.quantize(x)  # [B*T,]
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)  # [B*T, cb-size]
        embed_ind = self.postprocess_emb(embed_ind, shape)  # [B, T]
        quantize = self.dequantize(embed_ind)  # [B, T, D]

        if self.training:
            ### Update codebook by EMA
            embed_onehot_sum = embed_onehot.sum(0)  # [cb-size,]
            embed_sum = x.t() @ embed_onehot  # [D, cb-size]
            if is_distributed():
                dist.all_reduce(embed_onehot_sum)
                dist.all_reduce(embed_sum)
            # Update EMA cluster count N_i^t: eq. (6) in the VQ-VAE paper
            self.cluster_size.data.mul_(self.decay).add_(embed_onehot_sum, alpha=1 - self.decay)
            # Update EMA embed sum: eq. (7) in the VQ-VAE paper
            self.embed_avg.data.mul_(self.decay).add_(embed_sum.t(), alpha=1 - self.decay)
            # Apply Laplace smoothing to the cluster counts
            n = self.cluster_size.sum()
            cluster_size = (self.cluster_size + self.epsilon) / (n + self.codebook_size * self.epsilon) * n
            # Normalize the EMA embed: eq. (8) in the VQ-VAE paper
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)

        return quantize, embed_ind
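For reference, the EMA branch in forward matches the appendix updates of the VQ-VAE paper that the inline comments cite. With decay $\gamma$, per-code batch count $n_i^{(t)}$, and per-code batch sum $\sum_j x_{i,j}^{(t)}$ (all-reduced across workers before use):

$$N_i^{(t)} = \gamma\, N_i^{(t-1)} + (1-\gamma)\, n_i^{(t)}, \qquad m_i^{(t)} = \gamma\, m_i^{(t-1)} + (1-\gamma) \sum_j x_{i,j}^{(t)}, \qquad e_i^{(t)} = \frac{m_i^{(t)}}{N_i^{(t)}},$$

where $N_i$ is Laplace-smoothed before the division so that codes with (near-)zero usage cannot produce a zero denominator.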
class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.
    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            a randomly selected vector from the current batch.
        commitment_weight (float): Weight for commitment loss.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        kmeans_init: bool = True,
        kmeans_iters: int = 50,
        threshold_ema_dead_code: int = 2,
        commitment_weight: float = 1.0,
    ):
        super().__init__()
        _codebook_dim: int = default(codebook_dim, dim)

        requires_projection = _codebook_dim != dim
        self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
        self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight

        self._codebook = EuclideanCodebook(
            dim=_codebook_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size

    @property
    def codebook(self):
        return self._codebook.embed

    def encode(self, x):
        x = rearrange(x, "b d n -> b n d")
        x = self.project_in(x)
        embed_in = self._codebook.encode(x)
        return embed_in

    def decode(self, embed_ind):
        quantize = self._codebook.decode(embed_ind)
        quantize = self.project_out(quantize)
        quantize = rearrange(quantize, "b n d -> b d n")
        return quantize

    def forward(self, x):
        device = x.device
        x = x.transpose(1, 2).contiguous()  # [b d n] -> [b n d]
        x = self.project_in(x)

        quantize, embed_ind = self._codebook(x)

        if self.training:
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=self.training)

        if self.training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

        quantize = self.project_out(quantize)
        quantize = quantize.transpose(1, 2).contiguous()  # [b n d] -> [b d n]
        return quantize, embed_ind, loss
| 384 | 
         
            +
            class ResidualVectorQuantization(nn.Module):
         
     | 
| 385 | 
         
            +
                """Residual vector quantization implementation.
         
     | 
| 386 | 
         
            +
                Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
         
     | 
| 387 | 
         
            +
                """
         
     | 
| 388 | 
         
            +
             
     | 
| 389 | 
         
            +
                def __init__(self, *, num_quantizers, **kwargs):
         
     | 
| 390 | 
         
            +
                    super().__init__()
         
     | 
| 391 | 
         
            +
                    self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
         
     | 
| 392 | 
         
            +
             
     | 
| 393 | 
         
            +
                def forward(self, x, n_q: tp.Optional[int] = None):
         
     | 
| 394 | 
         
            +
                    quantized_out = 0.0
         
     | 
| 395 | 
         
            +
                    residual = x
         
     | 
| 396 | 
         
            +
             
     | 
| 397 | 
         
            +
                    all_losses = []
         
     | 
| 398 | 
         
            +
                    all_indices = []
         
     | 
| 399 | 
         
            +
             
     | 
| 400 | 
         
            +
                    n_q = n_q or len(self.layers)
         
     | 
| 401 | 
         
            +
             
     | 
| 402 | 
         
            +
                    for layer in self.layers[:n_q]:
         
     | 
| 403 | 
         
            +
                        quantized, indices, loss = layer(residual)
         
     | 
| 404 | 
         
            +
                        residual = residual - quantized
         
     | 
| 405 | 
         
            +
                        quantized_out = quantized_out + quantized
         
     | 
| 406 | 
         
            +
             
     | 
| 407 | 
         
            +
                        all_indices.append(indices)
         
     | 
| 408 | 
         
            +
                        all_losses.append(loss)
         
     | 
| 409 | 
         
            +
             
     | 
| 410 | 
         
            +
                    out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
         
     | 
| 411 | 
         
            +
                    return quantized_out, out_indices, out_losses
         
     | 
| 412 | 
         
            +
             
     | 
| 413 | 
         
            +
                def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
         
     | 
| 414 | 
         
            +
                    residual = x
         
     | 
| 415 | 
         
            +
                    all_indices = []
         
     | 
| 416 | 
         
            +
                    n_q = n_q or len(self.layers)
         
     | 
| 417 | 
         
            +
                    for layer in self.layers[:n_q]:
         
     | 
| 418 | 
         
            +
                        indices = layer.encode(residual)
         
     | 
| 419 | 
         
            +
                        quantized = layer.decode(indices)
         
     | 
| 420 | 
         
            +
                        residual = residual - quantized
         
     | 
| 421 | 
         
            +
                        all_indices.append(indices)
         
     | 
| 422 | 
         
            +
                    out_indices = torch.stack(all_indices)
         
     | 
| 423 | 
         
            +
                    return out_indices
         
     | 
| 424 | 
         
            +
             
     | 
| 425 | 
         
            +
                def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
         
     | 
| 426 | 
         
            +
                    quantized_out = torch.tensor(0.0, device=q_indices.device)
         
     | 
| 427 | 
         
            +
                    for i, indices in enumerate(q_indices):
         
     | 
| 428 | 
         
            +
                        layer = self.layers[i]
         
     | 
| 429 | 
         
            +
                        quantized = layer.decode(indices)
         
     | 
| 430 | 
         
            +
                        quantized_out = quantized_out + quantized
         
     | 
| 431 | 
         
            +
                    return quantized_out
         
     | 
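
Taken together, the two classes above form the quantizer stack: each VectorQuantization layer snaps its input to the nearest codebook entry, and ResidualVectorQuantization chains layers over the leftover residual. A minimal round-trip sketch (the import path and constructor keywords are assumptions read off the code above; k-means init is disabled so the codebook is usable without a warm-up pass):

import torch

from higgs_audio.audio_processing.quantization.core_vq_lsx_version import (
    ResidualVectorQuantization,
)

rvq = ResidualVectorQuantization(num_quantizers=4, dim=64, codebook_size=256, kmeans_init=False)
x = torch.randn(2, 64, 50)             # [b, d, n] latent frames
quantized, indices, losses = rvq(x)    # training-mode forward: straight-through + commit loss
assert quantized.shape == x.shape
assert indices.shape == (4, 2, 50)     # one code per (quantizer, batch, frame)

rvq.eval()
codes = rvq.encode(x, n_q=2)           # coarser: only the first two quantizers
x_hat = rvq.decode(codes)              # sum of per-layer codebook vectors, [2, 64, 50]

Using fewer quantizers at encode time trades reconstruction quality for bitrate, which is the knob the bandwidth logic in vq.py below turns.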
    	
higgs_audio/audio_processing/quantization/ddp_utils.py
ADDED

@@ -0,0 +1,197 @@
import logging
import logging.config
import random
import subprocess
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors
import torch.optim
import torch.utils.data
from packaging import version
from omegaconf import OmegaConf


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def is_logging_process():
    return not dist.is_initialized() or dist.get_rank() == 0


def get_logger(cfg, name=None):
    # log_file_path is used when unit testing
    if is_logging_process():
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_config, resolve=True))
        return logging.getLogger(name)


# from https://github.com/Lightning-AI/lightning-bolts/blob/5d61197cd2f491f69e238137a5edabe80ae14ad9/pl_bolts/models/self_supervised/simclr/simclr_module.py#L20
class SyncFunction(torch.autograd.Function):
    """All-gather across ranks that keeps gradients flowing back to the local shard."""

    @staticmethod
    # @torch.no_grad()
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        # Return only the gradient slice that corresponds to this rank's input.
        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]


def get_timestamp():
    return datetime.now().strftime("%y%m%d-%H%M%S")


def get_commit_hash():
    message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    return message.strip().decode("utf-8")


class DDP(DistributedDataParallel):
    """
    Override the forward call so it dispatches to the module's
    training_step / test_step / validation_step respectively.
    """

    def forward(self, *inputs, **kwargs):  # pragma: no cover
        if version.parse(torch.__version__[:6]) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # We'll return the output object verbatim since it is a freeform
                # object. We need to find any tensors in this object, though,
                # because we need to figure out which parameters were used during
                # this forward pass, to ensure we short circuit reduction for any
                # unused parameters. Only if `find_unused_parameters` is set.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            from torch.nn.parallel.distributed import (
                logging,
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
            )

            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if
                # needed
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(work, self._divide_by_initial_world_size)

                # Calling _rebuild_buckets before forward computation,
                # It may allocate new buckets before deallocating old buckets
                # inside _rebuild_buckets. To save peak memory usage,
                # call _rebuild_buckets before the peak memory usage increases
                # during forward computation.
                # This should be called only once during whole training period.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                buffer_hook_registered = hasattr(self, "buffer_hook")
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in backwards pass or not.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                # sync params according to location (before/after forward) user
                # specified as part of hook, if hook was specified.
                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # We'll return the output object verbatim since it is a freeform
                    # object. We need to find any tensors in this object, though,
                    # because we need to figure out which parameters were used during
                    # this forward pass, to ensure we short circuit reduction for any
                    # unused parameters. Only if `find_unused_parameters` is set.
                    if self.find_unused_parameters and not self.static_graph:
                        # Do not need to populate this for static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

            # TODO: DDPSink is currently enabled for unused parameter detection and
            # static graph training for first iteration.
            if (self.find_unused_parameters and not self.static_graph) or (
                self.static_graph and self.num_iterations == 1
            ):
                state_dict = {
                    "static_graph": self.static_graph,
                    "num_iterations": self.num_iterations,
                }

                output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(output)
                output_placeholders = [None for _ in range(len(output_tensor_list))]
                # Do not touch tensors that have no grad_fn, which can cause issues
                # such as https://github.com/pytorch/pytorch/issues/60733
                for i, output in enumerate(output_tensor_list):
                    if torch.is_tensor(output) and output.grad_fn is None:
                        output_placeholders[i] = output

                # When find_unused_parameters=True, makes tensors which require grad
                # run through the DDPSink backward pass. When not all outputs are
                # used in loss, this makes those corresponding tensors receive
                # undefined gradient which the reducer then handles to ensure
                # param.grad field is not touched and we don't error out.
                passthrough_tensor_list = _DDPSink.apply(
                    self.reducer,
                    state_dict,
                    *output_tensor_list,
                )
                for i in range(len(output_placeholders)):
                    if output_placeholders[i] is None:
                        output_placeholders[i] = passthrough_tensor_list[i]

                # Reconstruct output data structure.
                output = _tree_unflatten_with_rref(output_placeholders, treespec, output_is_rref)
        return output
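
The DDP subclass above replaces the usual module(*inputs) dispatch with calls to training_step, test_step, or validation_step, selected by the wrapped module's training and testing flags. A hedged sketch of the interface a module would need before wrapping (ToyModel and its attribute layout are hypothetical illustrations, not from this repo):

import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(8, 1)
        self.testing = False            # read by DDP.forward to pick test_step

    def training_step(self, batch):
        x, y = batch
        return torch.nn.functional.mse_loss(self.net(x), y)

    def validation_step(self, batch):
        with torch.no_grad():
            x, y = batch
            return torch.nn.functional.mse_loss(self.net(x), y)

    test_step = validation_step

# Inside an initialized process group one would then wrap and call it as:
#   model = DDP(ToyModel().cuda(), device_ids=[local_rank])
#   loss = model(batch)   # dispatches to training_step in train mode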
    	
higgs_audio/audio_processing/quantization/distrib.py
ADDED

@@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Torch distributed utilities."""

import typing as tp

import torch


def rank():
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def world_size():
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def is_distributed():
    return world_size() > 1


def all_reduce(tensor: torch.Tensor, op=torch.distributed.ReduceOp.SUM):
    if is_distributed():
        return torch.distributed.all_reduce(tensor, op)


def _is_complex_or_float(tensor):
    return torch.is_floating_point(tensor) or torch.is_complex(tensor)


def _check_number_of_params(params: tp.List[torch.Tensor]):
    # utility function to check that the number of params in all workers is the same,
    # and thus avoid a deadlock with distributed all reduce.
    if not is_distributed() or not params:
        return
    tensor = torch.tensor([len(params)], device=params[0].device, dtype=torch.long)
    all_reduce(tensor)
    if tensor.item() != len(params) * world_size():
        # If not all the workers have the same number, for at least one of them,
        # this inequality will be verified.
        raise RuntimeError(
            f"Mismatch in number of params: ours is {len(params)}, at least one worker has a different one."
        )


def broadcast_tensors(tensors: tp.Iterable[torch.Tensor], src: int = 0):
    """Broadcast the tensors from the given parameters to all workers.
    This can be used to ensure that all workers have the same model to start with.
    """
    if not is_distributed():
        return
    tensors = [tensor for tensor in tensors if _is_complex_or_float(tensor)]
    _check_number_of_params(tensors)
    handles = []
    for tensor in tensors:
        handle = torch.distributed.broadcast(tensor.data, src=src, async_op=True)
        handles.append(handle)
    for handle in handles:
        handle.wait()


def sync_buffer(buffers, average=True):
    """
    Sync grad for buffers. If average is False, broadcast instead of averaging.
    """
    if not is_distributed():
        return
    handles = []
    for buffer in buffers:
        if torch.is_floating_point(buffer.data):
            if average:
                handle = torch.distributed.all_reduce(buffer.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            else:
                handle = torch.distributed.broadcast(buffer.data, src=0, async_op=True)
            handles.append((buffer, handle))
    for buffer, handle in handles:
        handle.wait()
        if average:
            buffer.data /= world_size()


def sync_grad(params):
    """
    Simpler alternative to DistributedDataParallel, that doesn't rely
    on any black magic. For simple models it can also be as fast.
    Just call this on your model parameters after the call to backward!
    """
    if not is_distributed():
        return
    handles = []
    for p in params:
        if p.grad is not None:
            handle = torch.distributed.all_reduce(p.grad.data, op=torch.distributed.ReduceOp.SUM, async_op=True)
            handles.append((p, handle))
    for p, handle in handles:
        handle.wait()
        p.grad.data /= world_size()


def average_metrics(metrics: tp.Dict[str, float], count=1.0):
    """Average a dictionary of metrics across all workers, using the optional
    `count` as unnormalized weight.
    """
    if not is_distributed():
        return metrics
    keys, values = zip(*metrics.items())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tensor = torch.tensor(list(values) + [1], device=device, dtype=torch.float32)
    tensor *= count
    all_reduce(tensor)
    averaged = (tensor[:-1] / tensor[-1]).cpu().tolist()
    return dict(zip(keys, averaged))
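
These helpers are deliberately no-ops outside a distributed run, which makes them safe to call unconditionally. A small sketch (the import path is an assumption based on this package layout):

from higgs_audio.audio_processing.quantization.distrib import average_metrics, sync_grad

# Without an initialized process group, world_size() is 1, so this is a passthrough.
metrics = {"loss": 0.52, "commit_loss": 0.03}
print(average_metrics(metrics, count=8))   # {'loss': 0.52, 'commit_loss': 0.03}

# In an actual multi-process run, the intended pattern for sync_grad is:
#   loss.backward()
#   sync_grad(model.parameters())   # all-reduce, then average grads by world size
#   optimizer.step()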
    	
higgs_audio/audio_processing/quantization/vq.py
ADDED

@@ -0,0 +1,116 @@
| 1 | 
         
            +
            # Copyright (c) Meta Platforms, Inc. and affiliates.
         
     | 
| 2 | 
         
            +
            # All rights reserved.
         
     | 
| 3 | 
         
            +
            #
         
     | 
| 4 | 
         
            +
            # This source code is licensed under the license found in the
         
     | 
| 5 | 
         
            +
            # LICENSE file in the root directory of this source tree.
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            """Residual vector quantizer implementation."""
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            from dataclasses import dataclass, field
         
     | 
| 10 | 
         
            +
            import math
         
     | 
| 11 | 
         
            +
            import typing as tp
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            import torch
         
     | 
| 14 | 
         
            +
            from torch import nn
         
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            # from .core_vq import ResidualVectorQuantization
         
     | 
| 17 | 
         
            +
            from .core_vq_lsx_version import ResidualVectorQuantization
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            @dataclass
         
     | 
| 21 | 
         
            +
            class QuantizedResult:
         
     | 
| 22 | 
         
            +
                quantized: torch.Tensor
         
     | 
| 23 | 
         
            +
                codes: torch.Tensor
         
     | 
| 24 | 
         
            +
                bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
         
     | 
| 25 | 
         
            +
                penalty: tp.Optional[torch.Tensor] = None
         
     | 
| 26 | 
         
            +
                metrics: dict = field(default_factory=dict)
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            class ResidualVectorQuantizer(nn.Module):
         
     | 
| 30 | 
         
            +
                """Residual Vector Quantizer.
         
     | 
| 31 | 
         
            +
                Args:
         
     | 
| 32 | 
         
            +
                    dimension (int): Dimension of the codebooks.
         
     | 
| 33 | 
         
            +
                    n_q (int): Number of residual vector quantizers used.
         
     | 
| 34 | 
         
            +
                    bins (int): Codebook size.
         
     | 
| 35 | 
         
            +
                    decay (float): Decay for exponential moving average over the codebooks.
         
     | 
| 36 | 
         
            +
                    kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
         
     | 
| 37 | 
         
            +
                    kmeans_iters (int): Number of iterations used for kmeans initialization.
         
     | 
| 38 | 
         
            +
                    threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
         
     | 
| 39 | 
         
            +
                        that have an exponential moving average cluster size less than the specified threshold with
         
     | 
| 40 | 
         
            +
                        randomly selected vector from the current batch.
         
     | 
| 41 | 
         
            +
                """
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
                def __init__(
         
     | 
| 44 | 
         
            +
                    self,
         
     | 
| 45 | 
         
            +
                    dimension: int = 256,
         
     | 
| 46 | 
         
            +
                    codebook_dim: int = None,
         
     | 
| 47 | 
         
            +
                    n_q: int = 8,
         
     | 
| 48 | 
         
            +
                    bins: int = 1024,
         
     | 
| 49 | 
         
            +
                    decay: float = 0.99,
         
     | 
| 50 | 
         
            +
                    kmeans_init: bool = True,
         
     | 
| 51 | 
         
            +
                    kmeans_iters: int = 50,
         
     | 
| 52 | 
         
            +
+        threshold_ema_dead_code: int = 2,
+    ):
+        super().__init__()
+        self.n_q = n_q
+        self.dimension = dimension
+        self.codebook_dim = codebook_dim
+        self.bins = bins
+        self.decay = decay
+        self.kmeans_init = kmeans_init
+        self.kmeans_iters = kmeans_iters
+        self.threshold_ema_dead_code = threshold_ema_dead_code
+        self.vq = ResidualVectorQuantization(
+            dim=self.dimension,
+            codebook_dim=self.codebook_dim,
+            codebook_size=self.bins,
+            num_quantizers=self.n_q,
+            decay=self.decay,
+            kmeans_init=self.kmeans_init,
+            kmeans_iters=self.kmeans_iters,
+            threshold_ema_dead_code=self.threshold_ema_dead_code,
+        )
+
+    def forward(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None):  # -> QuantizedResult:
+        """Residual vector quantization on the given input tensor.
+        Args:
+            x (torch.Tensor): Input tensor.
+            sample_rate (int): Sample rate of the input tensor.
+            bandwidth (float): Target bandwidth.
+        Returns:
+            QuantizedResult:
+                The quantized (or approximately quantized) representation with
+                the associated bandwidth and any penalty term for the loss.
+        """
+        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
+        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
+        quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+        bw = torch.tensor(n_q * bw_per_q).to(x)
+        return quantized, codes, bw, torch.mean(commit_loss)
+        # return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+    def get_num_quantizers_for_bandwidth(self, sample_rate: int, bandwidth: tp.Optional[float] = None) -> int:
+        """Return n_q based on the specified target bandwidth."""
+        bw_per_q = self.get_bandwidth_per_quantizer(sample_rate)
+        n_q = self.n_q
+        if bandwidth and bandwidth > 0.0:
+            n_q = int(max(1, math.floor(bandwidth / bw_per_q)))
+        return n_q
+
+    def get_bandwidth_per_quantizer(self, sample_rate: int):
+        """Return bandwidth per quantizer for a given input sample rate."""
+        return math.log2(self.bins) * sample_rate / 1000
+
+    def encode(self, x: torch.Tensor, sample_rate: int, bandwidth: tp.Optional[float] = None) -> torch.Tensor:
+        """Encode the input tensor with the specified sample rate at the given bandwidth.
+        The RVQ encode method sets the appropriate number of quantizers to use
+        and returns indices for each quantizer.
+        """
+        n_q = self.get_num_quantizers_for_bandwidth(sample_rate, bandwidth)
+        codes = self.vq.encode(x, n_q=n_q)
+        return codes
+
+    def decode(self, codes: torch.Tensor) -> torch.Tensor:
+        """Decode the given codes to the quantized representation."""
+        quantized = self.vq.decode(codes)
+        return quantized
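Note: the bandwidth bookkeeping above can be checked by hand. Each quantizer costs log2(bins) bits per latent frame, so `get_bandwidth_per_quantizer` returns log2(bins) * sample_rate / 1000 kbps, and `get_num_quantizers_for_bandwidth` floors the target bandwidth against that. A minimal standalone sketch, assuming a codebook size of 1024, a 50 Hz latent frame rate, and a 3 kbps target (illustrative values only, not taken from this diff):

    import math

    bins = 1024          # assumed codebook size
    frame_rate = 50      # assumed latent frames per second (the `sample_rate` argument above)
    target_bw = 3.0      # assumed target bandwidth in kbps

    bw_per_q = math.log2(bins) * frame_rate / 1000       # 10 bits/frame * 50 Hz = 0.5 kbps
    n_q = int(max(1, math.floor(target_bw / bw_per_q)))  # floor(3.0 / 0.5) = 6 codebooks
    print(bw_per_q, n_q)                                 # 0.5 6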
    	
higgs_audio/audio_processing/semantic_module.py
ADDED
@@ -0,0 +1,310 @@
+# Based on code from: https://github.com/zhenye234/xcodec
+# Licensed under MIT License
+# Modifications by BosonAI
+
+import torch
+import torch.nn as nn
+
+
+class Conv1d1x1(nn.Conv1d):
+    """1x1 Conv1d."""
+
+    def __init__(self, in_channels, out_channels, bias=True):
+        super(Conv1d1x1, self).__init__(in_channels, out_channels, kernel_size=1, bias=bias)
+
+
+class Conv1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = -1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        if padding < 0:
+            padding = (kernel_size - 1) // 2 * dilation
+        self.dilation = dilation
+        self.conv = nn.Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Float tensor variable with the shape (B, C, T).
+        Returns:
+            Tensor: Float tensor variable with the shape (B, C, T).
+        """
+        x = self.conv(x)
+        return x
+
+
+class ResidualUnit(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size=3,
+        dilation=1,
+        bias=False,
+        nonlinear_activation="ELU",
+        nonlinear_activation_params={},
+    ):
+        super().__init__()
+        self.activation = getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
+        self.conv1 = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            dilation=dilation,
+            bias=bias,
+        )
+        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
+
+    def forward(self, x):
+        y = self.conv1(self.activation(x))
+        y = self.conv2(self.activation(y))
+        return x + y
+
+
+class ConvTranspose1d(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int,
+        padding=-1,
+        output_padding=-1,
+        groups=1,
+        bias=True,
+    ):
+        super().__init__()
+        if padding < 0:
+            padding = (stride + 1) // 2
+        if output_padding < 0:
+            output_padding = 1 if stride % 2 else 0
+        self.deconv = nn.ConvTranspose1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            output_padding=output_padding,
+            groups=groups,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Float tensor variable with the shape (B, C, T).
+        Returns:
+            Tensor: Float tensor variable with the shape (B, C', T').
+        """
+        x = self.deconv(x)
+        return x
+
+
+class EncoderBlock(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        dilations=(1, 1),
+        unit_kernel_size=3,
+        bias=True,
+    ):
+        super().__init__()
+        self.res_units = torch.nn.ModuleList()
+        for dilation in dilations:
+            self.res_units += [
+                ResidualUnit(
+                    in_channels,
+                    in_channels,
+                    kernel_size=unit_kernel_size,
+                    dilation=dilation,
+                )
+            ]
+        self.num_res = len(self.res_units)
+
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3 if stride == 1 else (2 * stride),  # special case: stride=1, do not use kernel=2
+            stride=stride,
+            bias=bias,
+        )
+
+    def forward(self, x):
+        for idx in range(self.num_res):
+            x = self.res_units[idx](x)
+        x = self.conv(x)
+        return x
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        input_channels: int,
+        encode_channels: int,
+        channel_ratios=(1, 1),
+        strides=(1, 1),
+        kernel_size=3,
+        bias=True,
+        block_dilations=(1, 1),
+        unit_kernel_size=3,
+    ):
+        super().__init__()
+        assert len(channel_ratios) == len(strides)
+
+        self.conv = Conv1d(
+            in_channels=input_channels,
+            out_channels=encode_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            bias=False,
+        )
+        self.conv_blocks = torch.nn.ModuleList()
+        in_channels = encode_channels
+        for idx, stride in enumerate(strides):
+            out_channels = int(encode_channels * channel_ratios[idx])  # could be float
+            self.conv_blocks += [
+                EncoderBlock(
+                    in_channels,
+                    out_channels,
+                    stride,
+                    dilations=block_dilations,
+                    unit_kernel_size=unit_kernel_size,
+                    bias=bias,
+                )
+            ]
+            in_channels = out_channels
+        self.num_blocks = len(self.conv_blocks)
+        self.out_channels = out_channels
+
+    def forward(self, x):
+        x = self.conv(x)
+        for i in range(self.num_blocks):
+            x = self.conv_blocks[i](x)
+        return x
+
+
+class DecoderBlock(nn.Module):
+    """Decoder block (no up-sampling)"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        stride: int,
+        dilations=(1, 1),
+        unit_kernel_size=3,
+        bias=True,
+    ):
+        super().__init__()
+
+        if stride == 1:
+            self.conv = Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,  # fix kernel=3 when stride=1 for unchanged shape
+                stride=stride,
+                bias=bias,
+            )
+        else:
+            self.conv = ConvTranspose1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=(2 * stride),
+                stride=stride,
+                bias=bias,
+            )
+
+        self.res_units = torch.nn.ModuleList()
+        for idx, dilation in enumerate(dilations):
+            self.res_units += [
+                ResidualUnit(
+                    out_channels,
+                    out_channels,
+                    kernel_size=unit_kernel_size,
+                    dilation=dilation,
+                )
+            ]
+        self.num_res = len(self.res_units)
+
+    def forward(self, x):
+        x = self.conv(x)
+        for idx in range(self.num_res):
+            x = self.res_units[idx](x)
+        return x
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        code_dim: int,
+        output_channels: int,
+        decode_channels: int,
+        channel_ratios=(1, 1),
+        strides=(1, 1),
+        kernel_size=3,
+        bias=True,
+        block_dilations=(1, 1),
+        unit_kernel_size=3,
+    ):
+        super().__init__()
+        assert len(channel_ratios) == len(strides)
+
+        self.conv1 = Conv1d(
+            in_channels=code_dim,
+            out_channels=int(decode_channels * channel_ratios[0]),
+            kernel_size=kernel_size,
+            stride=1,
+            bias=False,
+        )
+
+        self.conv_blocks = torch.nn.ModuleList()
+        for idx, stride in enumerate(strides):
+            in_channels = int(decode_channels * channel_ratios[idx])
+            if idx < (len(channel_ratios) - 1):
+                out_channels = int(decode_channels * channel_ratios[idx + 1])
+            else:
+                out_channels = decode_channels
+            self.conv_blocks += [
+                DecoderBlock(
+                    in_channels,
+                    out_channels,
+                    stride,
+                    dilations=block_dilations,
+                    unit_kernel_size=unit_kernel_size,
+                    bias=bias,
+                )
+            ]
+        self.num_blocks = len(self.conv_blocks)
+
+        self.conv2 = Conv1d(out_channels, output_channels, kernel_size, 1, bias=False)
+
+    def forward(self, z):
+        x = self.conv1(z)
+        for i in range(self.num_blocks):
+            x = self.conv_blocks[i](x)
+        x = self.conv2(x)
+        return x
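Note: with the default strides of (1, 1) the modules above preserve the time axis; with real strides, each EncoderBlock downsamples time by its stride and the matching DecoderBlock's ConvTranspose1d restores it, so the total compression factor is the product of the strides. A quick shape sanity check with hypothetical hyperparameters (the actual configs ship elsewhere in the repo, not in this diff):

    import torch
    from higgs_audio.audio_processing.semantic_module import Encoder, Decoder

    enc = Encoder(input_channels=1, encode_channels=32,
                  channel_ratios=(2, 4), strides=(2, 5))
    dec = Decoder(code_dim=enc.out_channels, output_channels=1, decode_channels=32,
                  channel_ratios=(4, 2), strides=(5, 2))

    wav = torch.randn(1, 1, 16000)  # (B, C, T)
    z = enc(wav)                    # T / (2 * 5) -> torch.Size([1, 128, 1600])
    rec = dec(z)                    # T restored  -> torch.Size([1, 1, 16000])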
    	
higgs_audio/constants.py
ADDED
@@ -0,0 +1,3 @@
+AUDIO_IN_TOKEN = "<|AUDIO|>"
+AUDIO_OUT_TOKEN = "<|AUDIO_OUT|>"
+EOS_TOKEN = "<|end_of_text|>"
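Note: these are the literal placeholder strings that tokenize into the `audio_in_token_id` / `audio_out_token_id` ids the collator below is constructed with. A hypothetical prompt fragment showing their intended placement (the real chat template lives in the dataset code, not in this diff):

    from higgs_audio.constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN, EOS_TOKEN

    user_turn = f"Transcribe the following clip: {AUDIO_IN_TOKEN}"  # audio understanding
    assistant_turn = f"{AUDIO_OUT_TOKEN}{EOS_TOKEN}"                # audio generation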
    	
higgs_audio/data_collator/__init__.py
ADDED
File without changes
    	
higgs_audio/data_collator/higgs_audio_collator.py
ADDED
@@ -0,0 +1,583 @@
| 1 | 
         
            +
            import librosa
         
     | 
| 2 | 
         
            +
            import torch
         
     | 
| 3 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 4 | 
         
            +
            import math
         
     | 
| 5 | 
         
            +
            import numpy as np
         
     | 
| 6 | 
         
            +
            from typing import List, Tuple, Dict
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 9 | 
         
            +
            from typing import List, Optional
         
     | 
| 10 | 
         
            +
            from transformers.models.whisper.processing_whisper import WhisperProcessor
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
            from ..dataset.chatml_dataset import ChatMLDatasetSample, RankedChatMLDatasetSampleTuple
         
     | 
| 13 | 
         
            +
            from ..model.utils import build_delay_pattern_mask
         
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
             
     | 
| 16 | 
         
            +
            def _ceil_to_nearest(n, round_to):
         
     | 
| 17 | 
         
            +
                return (n + round_to - 1) // round_to * round_to
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
            @dataclass
         
     | 
| 21 | 
         
            +
            class HiggsAudioBatchInput:
         
     | 
| 22 | 
         
            +
                input_ids: torch.LongTensor  # shape (bsz, seq_len).
         
     | 
| 23 | 
         
            +
                attention_mask: torch.Tensor  # shape (bsz, seq_len).
         
     | 
| 24 | 
         
            +
                audio_features: Optional[torch.Tensor]  # shape (num_audio_in, feature_dim, max_mel_seq_len).
         
     | 
| 25 | 
         
            +
                audio_feature_attention_mask: Optional[torch.Tensor]  # shape (num_audio_in, max_mel_seq_len).
         
     | 
| 26 | 
         
            +
                audio_out_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
         
     | 
| 27 | 
         
            +
                audio_out_ids_start: Optional[torch.LongTensor]  # shape (num_audio_out,)
         
     | 
| 28 | 
         
            +
                # The audio_out_ids_start_group_loc has the same length as audio_out_ids_start. It is used to recover group location in a batch for an audio segment
         
     | 
| 29 | 
         
            +
                # Currently, we concatenante audio segments along dim 0 to handle variadic audio segment length. However, in the alignment stage, we need the location information
         
     | 
| 30 | 
         
            +
                # For example,
         
     | 
| 31 | 
         
            +
                #  audio_out_ids_start = [0, 2, 4, 8]; and the first two audio segments come from the same sample in a batch, and other two come from different samples.
         
     | 
| 32 | 
         
            +
                #  This is a batch of 3 samples, then we will have the group location as:
         
     | 
| 33 | 
         
            +
                #  audio_out_ids_start_group_loc = [0, 0, 1, 2]
         
     | 
| 34 | 
         
            +
                audio_out_ids_start_group_loc: Optional[
         
     | 
| 35 | 
         
            +
                    torch.LongTensor
         
     | 
| 36 | 
         
            +
                ]  # shape (num_audio_out,), specify which a sample's group location in the batch
         
     | 
| 37 | 
         
            +
                audio_in_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_in_total_length)
         
     | 
| 38 | 
         
            +
                audio_in_ids_start: Optional[torch.LongTensor]  # shape (num_audio_in,)
         
     | 
| 39 | 
         
            +
                label_ids: Optional[torch.LongTensor]  # shape (bsz, seq_len)
         
     | 
| 40 | 
         
            +
                label_audio_ids: Optional[torch.LongTensor]  # shape (num_codebooks, audio_out_total_length)
         
     | 
| 41 | 
         
            +
                reward: Optional[float] = None
         
     | 
| 42 | 
         
            +
             
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
            class HiggsAudioSampleCollator:
         
     | 
| 45 | 
         
            +
                """Sample collator for Higgs-Audio model.
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                Args:
         
     | 
| 48 | 
         
            +
                    whisper_processor (WhisperProcessor): The whisper processor.
         
     | 
| 49 | 
         
            +
                    audio_in_token_id (int): The token id for audio-in.
         
     | 
| 50 | 
         
            +
                    audio_out_token_id (int): The token id for audio-out.
         
     | 
| 51 | 
         
            +
                    pad_token_id (int): The token id for padding.
         
     | 
| 52 | 
         
            +
                    audio_stream_bos_id (int): The token id for audio-stream beginning of sentence.
         
     | 
| 53 | 
         
            +
                    audio_stream_eos_id (int): The token id for audio-stream end of sentence.
         
     | 
| 54 | 
         
            +
                    round_to (int): The round-to value.
         
     | 
| 55 | 
         
            +
                    pad_left (bool): Whether to pad left.
         
     | 
| 56 | 
         
            +
                    return_audio_in_tokens (bool): Whether to return audio-in tokens.
         
     | 
| 57 | 
         
            +
                    use_delay_pattern (bool): Whether to use delay pattern.
         
     | 
| 58 | 
         
            +
                    disable_audio_codes_transform (bool): Whether to add bos and eos tokens to audio codes.
         
     | 
| 59 | 
         
            +
                    chunk_size_seconds (int): The chunk size in seconds.
         
     | 
| 60 | 
         
            +
                    add_new_bos_eos_for_long_chunk (bool): Whether to add new bos and eos tokens for long chunks.
         
     | 
| 61 | 
         
            +
                    mask_audio_out_token_label (bool): Whether to always mask the label associated with <|AUDIO_OUT|> token. Since we will always have `<|AUDIO_OUT|>` after `<|audio_bos|>`, we can safely mask <|AUDIO_OUT|>.
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
                """
         
     | 
| 64 | 
         

    def __init__(
        self,
        whisper_processor: WhisperProcessor,
        audio_in_token_id,
        audio_out_token_id,
        pad_token_id,
        audio_stream_bos_id,
        audio_stream_eos_id,
        round_to=8,
        pad_left=False,
        encode_whisper_embed=True,
        return_audio_in_tokens=True,
        audio_num_codebooks=None,
        use_delay_pattern=False,
        disable_audio_codes_transform=False,
        chunk_size_seconds=30,  # Maximum duration for each chunk
        add_new_bos_eos_for_long_chunk=True,
        mask_audio_out_token_label=True,
    ):
        self.whisper_processor = whisper_processor
        self.round_to = round_to
        self.pad_left = pad_left
        self.audio_in_token_id = audio_in_token_id
        self.audio_out_token_id = audio_out_token_id
        self.audio_stream_bos_id = audio_stream_bos_id
        self.audio_stream_eos_id = audio_stream_eos_id
        self.pad_token_id = pad_token_id
        self.encode_whisper_embed = encode_whisper_embed
        self.return_audio_in_tokens = return_audio_in_tokens
        self.audio_num_codebooks = audio_num_codebooks
        self.use_delay_pattern = use_delay_pattern
        if encode_whisper_embed:
            self.chunk_size_seconds = chunk_size_seconds
            self.chunk_size_samples = int(chunk_size_seconds * whisper_processor.feature_extractor.sampling_rate)
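            # e.g. with Whisper's 16 kHz feature extractor and the default
            # chunk_size_seconds=30, chunk_size_samples = 30 * 16000 = 480000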
        else:
            self.chunk_size_seconds = None
            self.chunk_size_samples = None
        self.disable_audio_codes_transform = disable_audio_codes_transform
        self.add_new_bos_eos_for_long_chunk = add_new_bos_eos_for_long_chunk
        self.mask_audio_out_token_label = mask_audio_out_token_label

    def _process_and_duplicate_audio_tokens(
        self,
        input_ids: torch.Tensor,
        audio_idx: int,
        wv: torch.Tensor,
        sr: int,
        labels: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
        """Process long audio and duplicate the corresponding audio tokens.

        Args:
            input_ids: Input token ids
            audio_idx: Index of the audio token in the sequence
            wv: Audio waveform
            sr: Sample rate
            labels: Optional label ids to be duplicated alongside the input ids

        Returns:
            Tuple of:
                - New input ids with duplicated audio tokens
                - New label ids (if labels were provided) or None
                - Number of chunks created
        """
        # Calculate the number of chunks needed
        total_samples = len(wv)
        num_chunks = math.ceil(total_samples / self.chunk_size_samples)

        if num_chunks <= 1:
            return input_ids, labels, 1

        # Get the three tokens: <|audio_bos|><|AUDIO|><|audio_eos|>
        audio_token_seq = input_ids[audio_idx - 1 : audio_idx + 2]
        # Duplicate the sequence for each chunk
        duplicated_sequence = audio_token_seq.repeat(num_chunks)

        # Create new input_ids with the duplicated tokens
        new_input_ids = torch.cat(
            [
                input_ids[: audio_idx - 1],
                duplicated_sequence,
                input_ids[audio_idx + 2 :],
            ]
        )

        # If labels are provided, duplicate them as well
        new_labels = None
        if labels is not None:
            label_seq = labels[audio_idx - 1 : audio_idx + 2]
            duplicated_labels = label_seq.repeat(num_chunks)
            new_labels = torch.cat([labels[: audio_idx - 1], duplicated_labels, labels[audio_idx + 2 :]])

        return new_input_ids, new_labels, num_chunks

    def __call__(self, batch: List[ChatMLDatasetSample]):
        """Collate the input data with support for long audio processing."""

        label_ids = None
        label_audio_ids = None
        if all([ele.label_ids is None for ele in batch]):
            return_labels = False
        else:
            return_labels = True

        if self.encode_whisper_embed:
            # Process each sample in the batch to handle long audio
            # TODO(?) The implementation here can be optimized.
            processed_batch = []
            for i in range(len(batch)):
                sample = batch[i]
                audio_in_mask = sample.input_ids == self.audio_in_token_id
                audio_in_indices = torch.where(audio_in_mask)[0]
                audio_out_mask = sample.input_ids == self.audio_out_token_id

                # Process each audio token and duplicate if needed
                modified_input_ids = sample.input_ids
                modified_labels = sample.label_ids if return_labels else None
                modified_waveforms_concat = []
                modified_waveforms_start = []
                modified_sample_rate = []
                offset = 0  # Track position changes from duplicating tokens
                curr_wv_offset = 0

                # Process input audio tokens
                for idx, audio_idx in enumerate(audio_in_indices):
                    # Get the audio for this token
                    wv, sr = sample.get_wv(idx)  # Use idx since we want the original audio index
                    if sr != self.whisper_processor.feature_extractor.sampling_rate:
                        resampled_wv = librosa.resample(
                            wv.cpu().numpy(),
                            orig_sr=sr,
                            target_sr=self.whisper_processor.feature_extractor.sampling_rate,
                        )
                    else:
                        resampled_wv = wv.cpu().numpy()
                    wv = torch.tensor(resampled_wv, device=wv.device)
                    sr = self.whisper_processor.feature_extractor.sampling_rate

                    # Process and duplicate tokens if necessary
                    token_pos = audio_idx + offset
                    modified_input_ids, modified_labels, num_chunks = self._process_and_duplicate_audio_tokens(
                        modified_input_ids, token_pos, wv, sr, modified_labels
                    )

                    # Update the audio data
                    for chunk_idx in range(num_chunks):
                        chunk_start = chunk_idx * self.chunk_size_samples
                        chunk_end = min((chunk_idx + 1) * self.chunk_size_samples, len(wv))
                        chunk_wv = wv[chunk_start:chunk_end]
                        modified_waveforms_concat.append(chunk_wv)
                        modified_waveforms_start.append(curr_wv_offset)
                        curr_wv_offset += len(chunk_wv)
                        modified_sample_rate.append(sr)

                    # Update the offset for the next iteration
                    offset += (num_chunks - 1) * 3  # Each new chunk adds 3 more tokens

                # Create a new sample with the modified tokens and audio data
                processed_sample = ChatMLDatasetSample(
                    input_ids=modified_input_ids,
                    label_ids=modified_labels if return_labels else sample.label_ids,
                    audio_ids_concat=sample.audio_ids_concat,
                    audio_ids_start=sample.audio_ids_start,
                    audio_waveforms_concat=torch.cat(modified_waveforms_concat)
                    if modified_waveforms_concat
                    else sample.audio_waveforms_concat,
                    audio_waveforms_start=torch.tensor(modified_waveforms_start, dtype=torch.long)
                    if modified_waveforms_start
                    else sample.audio_waveforms_start,
                    audio_sample_rate=torch.tensor(modified_sample_rate)
                    if modified_sample_rate
                    else sample.audio_sample_rate,
                    audio_speaker_indices=torch.tensor([]),
                    # FIXME(sxjscience): The logic here is not correct for audio_label_ids_concat.
                    audio_label_ids_concat=sample.audio_label_ids_concat,
                )
                # audio_in_chunk_len = len(torch.where(modified_input_ids == self.audio_in_token_id)[0])
                # assert audio_in_chunk_len == processed_sample.num_audios(), f"Mismatch: audio_in_chunk_len={audio_in_chunk_len}, processed_sample.num_audios()={processed_sample.num_audios()}"
                processed_batch.append(processed_sample)
        else:
            processed_batch = batch

        # Get the max sequence length based on the processed batch
        max_seq_length = _ceil_to_nearest(max([len(sample.input_ids) for sample in processed_batch]), self.round_to)

        # Get the ids for audio-in and audio-out for each batch
        audio_in_wv_l = []
        audio_in_ids_l = []
        audio_out_ids_l = []
        audio_out_ids_group_loc_l = []
        audio_in_label_ids_l = None
        audio_out_label_ids_l = None
        reward_l = []

        if return_labels:
            audio_out_no_train_flag = []  # Whether the audio-out data should be trained on or not.

        # Process the audio inputs and outputs
        for i in range(len(processed_batch)):
            audio_in_mask = processed_batch[i].input_ids == self.audio_in_token_id
            audio_out_mask = processed_batch[i].input_ids == self.audio_out_token_id
            audio_ids = torch.ones_like(processed_batch[i].input_ids)
            audio_ids[audio_in_mask ^ audio_out_mask] = torch.cumsum(audio_ids[audio_in_mask ^ audio_out_mask], 0) - 1
            audio_in_ids = audio_ids[audio_in_mask]
            audio_out_ids = audio_ids[audio_out_mask]

            if return_labels:
                audio_out_no_train_flag.append(processed_batch[i].label_ids[audio_out_mask] < 0)
                if self.mask_audio_out_token_label:
                    processed_batch[i].label_ids[audio_out_mask] = -100

            # Process audio inputs
            if self.return_audio_in_tokens:
                audio_in_ids_l.extend(
                    [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_in_ids]
                )
                if processed_batch[i].audio_label_ids_concat is not None:
                    if audio_in_label_ids_l is None:
                        audio_in_label_ids_l = []
                    audio_in_label_ids_l.extend(
                        [
                            processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                            for idx in audio_in_ids
                        ]
                    )

            audio_out_ids_l.extend(
                [processed_batch[i].get_audio_codes(idx)[: self.audio_num_codebooks, :] for idx in audio_out_ids]
            )
            audio_out_ids_group_loc_l.append(i)
            if processed_batch[i].reward is not None:
                reward_l.append(processed_batch[i].reward)

            if processed_batch[i].audio_label_ids_concat is not None:
                if audio_out_label_ids_l is None:
                    audio_out_label_ids_l = []
                audio_out_label_ids_l.extend(
                    [
                        processed_batch[i].get_audio_codes_labels(idx)[: self.audio_num_codebooks, :]
                        for idx in audio_out_ids
                    ]
                )

            if self.encode_whisper_embed:
                for idx in audio_in_ids:
                    wv, sr = processed_batch[i].get_wv(idx)
                    resampled_wv = wv.cpu().numpy()
                    # Split long audio into chunks
                    total_samples = len(resampled_wv)
                    for chunk_start in range(0, total_samples, self.chunk_size_samples):
                        chunk_end = min(chunk_start + self.chunk_size_samples, total_samples)
                        chunk = resampled_wv[chunk_start:chunk_end]
                        audio_in_wv_l.append(chunk)
            # assert len(audio_in_wv_l) == processed_batch[i].num_audios(), \
            #     f"Assertion failed: Mismatch in number of audios. " \
            #     f"Expected {processed_batch[i].num_audios()}, but got {len(audio_in_wv_l)} at index {i}."

        if return_labels:
            audio_out_no_train_flag = torch.cat(audio_out_no_train_flag, dim=0)

        # Process all audio features
        if len(audio_in_wv_l) > 0:
            feature_ret = self.whisper_processor.feature_extractor(
                audio_in_wv_l,
                sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
                return_attention_mask=True,
                padding="max_length",
            )
            audio_features = torch.from_numpy(feature_ret["input_features"])
            audio_feature_attention_mask = torch.from_numpy(feature_ret["attention_mask"])
        else:
            if self.encode_whisper_embed:
                audio_features = torch.zeros(
                    (
                        0,
                        self.whisper_processor.feature_extractor.feature_size,
                        self.whisper_processor.feature_extractor.nb_max_frames,
                    ),
                    dtype=torch.float32,
                )
                audio_feature_attention_mask = torch.zeros(
                    (0, self.whisper_processor.feature_extractor.nb_max_frames),
                    dtype=torch.int32,
                )
            else:
                audio_features = None
                audio_feature_attention_mask = None

        # Process audio input tokens
        if len(audio_in_ids_l) > 0:
            # Append audio-stream-bos and eos tokens
            new_audio_in_ids_l = []
            for ele in audio_in_ids_l:
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream-bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_bos_id,
                                dtype=torch.long,
                            ),
                            ele,
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_eos_id,
                                dtype=torch.long,
                            ),
                        ],
                        dim=1,
                    )
                    if self.use_delay_pattern:
                        audio_codes = build_delay_pattern_mask(
                            audio_codes.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)
                new_audio_in_ids_l.append(audio_codes)
            audio_in_ids = torch.cat(new_audio_in_ids_l, dim=1).long()
            audio_in_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_in_ids_l[:-1]]),
                dim=0,
            )
        else:
            audio_in_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_in_ids_start = torch.zeros(0, dtype=torch.long)

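        # Shape note: each clip's (num_codebooks, T) code matrix is wrapped above as
        # [bos, codes, eos] along dim 1, giving (num_codebooks, T + 2); audio_in_ids_start
        # records each clip's starting column in the concatenated audio_in_ids tensor.
        # The audio-out codes below receive the same bos/eos treatment.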
        # Process audio output tokens
        audio_out_ids_start_group_loc = None
        if len(audio_out_ids_l) > 0:
            new_audio_out_ids_l = []
            label_audio_ids_l = []
            for idx, ele in enumerate(audio_out_ids_l):
                if self.disable_audio_codes_transform:
                    # Do not add audio-stream-bos or eos tokens.
                    # This may indicate that the sample comes from ConstantLengthDatasetWithBuffer.
                    audio_codes = ele
                    if return_labels:
                        label_audio_ids = audio_out_label_ids_l[idx]
                else:
                    audio_codes = torch.cat(
                        [
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_bos_id,
                                dtype=torch.long,
                            ),
                            ele,
                            torch.full(
                                (ele.shape[0], 1),
                                self.audio_stream_eos_id,
                                dtype=torch.long,
                            ),
                        ],
                        dim=1,
                    )
                    if return_labels:
                        label_audio_ids = torch.cat(
                            [
                                torch.full((ele.shape[0], 1), -100, dtype=torch.long),
                                ele,
                                torch.full(
                                    (ele.shape[0], 1),
                                    self.audio_stream_eos_id,
                                    dtype=torch.long,
                                ),
                            ],
                            dim=1,
                        )
                    if self.use_delay_pattern:
                        audio_codes = build_delay_pattern_mask(
                            audio_codes.unsqueeze(0),
                            bos_token_id=self.audio_stream_bos_id,
                            pad_token_id=self.audio_stream_eos_id,
                        )[0].squeeze(0)
                        if return_labels:
                            label_audio_ids = build_delay_pattern_mask(
                                label_audio_ids.unsqueeze(0),
                                bos_token_id=-100,
                                pad_token_id=-100,
                            )[0].squeeze(0)
                new_audio_out_ids_l.append(audio_codes)

                if return_labels:
                    if audio_out_no_train_flag[idx]:
                        label_audio_ids[:] = -100
                    label_audio_ids_l.append(label_audio_ids)

            audio_out_ids = torch.cat(new_audio_out_ids_l, dim=1).long()
            if return_labels:
                label_audio_ids = torch.cat(label_audio_ids_l, dim=1).long()
            audio_out_ids_start = torch.cumsum(
                torch.tensor([0] + [audio_codes.shape[1] for audio_codes in new_audio_out_ids_l[:-1]]),
                dim=0,
            )
            audio_out_ids_start_group_loc = torch.tensor(audio_out_ids_group_loc_l, dtype=torch.long)
        else:
            audio_out_ids = torch.zeros((0, 0), dtype=torch.long)
            audio_out_ids_start = torch.zeros(0, dtype=torch.long)
            if return_labels:
                label_audio_ids = torch.zeros((0, 0), dtype=torch.long)

        reward = torch.tensor(reward_l, dtype=torch.float32)

        # Handle padding for input ids and attention mask
        if self.pad_left:
            input_ids = torch.stack(
                [
                    F.pad(
                        ele.input_ids,
                        (max_seq_length - len(ele.input_ids), 0),
                        value=self.pad_token_id,
                    )
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(
                            ele.label_ids,
                            (max_seq_length - len(ele.label_ids), 0),
                            value=-100,
                        )
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(
                        torch.ones_like(ele.input_ids),
                        (max_seq_length - len(ele.input_ids), 0),
                        value=0,
                    )
                    for ele in processed_batch
                ]
            )
        else:
            input_ids = torch.stack(
                [
                    F.pad(
                        ele.input_ids,
                        (0, max_seq_length - len(ele.input_ids)),
                        value=self.pad_token_id,
                    )
                    for ele in processed_batch
                ]
            )
            if return_labels:
                label_ids = torch.stack(
                    [
                        F.pad(
                            ele.label_ids,
                            (0, max_seq_length - len(ele.label_ids)),
                            value=-100,
                        )
                        for ele in processed_batch
                    ]
                )
            attention_mask = torch.stack(
                [
                    F.pad(
                        torch.ones_like(ele.input_ids),
                        (0, max_seq_length - len(ele.input_ids)),
                        value=0,
                    )
                    for ele in processed_batch
                ]
            )

        if not self.return_audio_in_tokens:
            audio_in_ids = None
            audio_in_ids_start = None

        # Apply the audio_num_codebooks limit if specified
        if self.audio_num_codebooks is not None:
            if audio_in_ids is not None:
                audio_in_ids = audio_in_ids[: self.audio_num_codebooks]
            if audio_out_ids is not None:
                audio_out_ids = audio_out_ids[: self.audio_num_codebooks]
            if label_audio_ids is not None:
                label_audio_ids = label_audio_ids[: self.audio_num_codebooks]

        return HiggsAudioBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            audio_features=audio_features,
            audio_feature_attention_mask=audio_feature_attention_mask,
            audio_out_ids=audio_out_ids,
            audio_out_ids_start=audio_out_ids_start,
            audio_out_ids_start_group_loc=audio_out_ids_start_group_loc,
            audio_in_ids=audio_in_ids,
            audio_in_ids_start=audio_in_ids_start,
            label_ids=label_ids,
            label_audio_ids=label_audio_ids,
            reward=reward,
        )


class HiggsAudioDPOSamplesCollator(HiggsAudioSampleCollator):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def __call__(self, batch: List[RankedChatMLDatasetSampleTuple]) -> HiggsAudioBatchInput:
        # Flatten the ranked ChatML samples: chosen samples first, then rejected ones
        chosen = []
        rejected = []

        for sample in batch:
            chosen.append(sample.max_score_sample())
            rejected.append(sample.min_score_sample())

        merged = chosen
        merged.extend(rejected)

        return super().__call__(batch=merged)
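For context when reviewing this diff, here is a minimal sketch of how the collator above could be wired up. The Whisper checkpoint name and all special-token ids are placeholders for illustration, not values taken from this commit; in practice they come from the Higgs audio tokenizer and model config.

# Hypothetical wiring; the checkpoint name and ids below are placeholders.
from transformers import WhisperProcessor

collator = HiggsAudioSampleCollator(
    whisper_processor=WhisperProcessor.from_pretrained("openai/whisper-large-v3"),
    audio_in_token_id=32000,    # placeholder id for <|AUDIO|>
    audio_out_token_id=32001,   # placeholder id for <|AUDIO_OUT|>
    pad_token_id=0,             # placeholder
    audio_stream_bos_id=1024,   # placeholder
    audio_stream_eos_id=1025,   # placeholder
)
# batch_input = collator(samples)  # samples: List[ChatMLDatasetSample] -> HiggsAudioBatchInput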
higgs_audio/data_types.py
ADDED
@@ -0,0 +1,38 @@
"""Basic data types for multimodal ChatML format."""

from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class AudioContent:
    audio_url: str
    # Base64-encoded audio bytes
    raw_audio: Optional[str] = None
    offset: Optional[float] = None
    duration: Optional[float] = None
    row_id: Optional[int] = None
    type: str = "audio"


@dataclass
class TextContent:
    text: str
    type: str = "text"


@dataclass
class Message:
    role: str
    content: Union[str, AudioContent, TextContent, List[Union[str, AudioContent, TextContent]]]
    recipient: Optional[str] = None


@dataclass
class ChatMLSample:
    """Dataclass to hold multimodal ChatML data."""

    messages: List[Message]
    start_index: Optional[int] = None  # We will mask messages[:start_index] when finetuning the LLM.
    misc: Optional[Dict] = None
    speaker: Optional[str] = None
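To illustrate how these dataclasses nest, a hand-built sample follows; the file paths and text are invented for the example and do not appear in this commit.

sample = ChatMLSample(
    messages=[
        Message(
            role="user",
            content=[
                TextContent(text="Please read this sentence aloud."),
                AudioContent(audio_url="examples/audios/reference_voice.wav"),  # hypothetical path
            ],
        ),
        Message(role="assistant", content=AudioContent(audio_url="examples/audios/response.wav")),
    ],
    start_index=1,  # mask messages[:1]; only the assistant turn contributes to the loss
)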
higgs_audio/dataset/__init__.py
ADDED

File without changes

higgs_audio/dataset/chatml_dataset.py
ADDED
@@ -0,0 +1,554 @@
import dacite
import pandas as pd
import torch
import json

import numpy as np
import multiprocessing as mp

from dataclasses import dataclass, fields
from abc import ABC, abstractmethod
from typing import Union, List, Dict, Optional

from ..data_types import ChatMLSample, TextContent, AudioContent
from ..constants import AUDIO_IN_TOKEN, AUDIO_OUT_TOKEN

from loguru import logger

# Whisper processor: 30 sec -> 3000 features.
# The audio tokenizer then downsamples by 4x, reducing the 3000 features to 750 per 30 sec, i.e. 25 Hz.
WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC = 25


@dataclass
class ChatMLDatasetSample:
    input_ids: torch.LongTensor  # Shape (seq_len,): The input text tokens.
    label_ids: torch.LongTensor  # Shape (seq_len,): The label ids.
    audio_ids_concat: torch.LongTensor  # Shape (num_codebooks, audio_seq_len): The concatenated audio tokens.
    # Here `audio_seq_len` is the length of the concatenated audio tokens.
    audio_ids_start: (
        torch.LongTensor
    )  # Shape (num_audios,): The start index of each audio in the concatenated audio tokens.
    audio_waveforms_concat: (
        torch.Tensor
    )  # Shape (total_wv_length,): The concatenated audio waveforms for audio-in features.
    audio_waveforms_start: (
        torch.LongTensor
    )  # Shape (num_audios,): The start index of each audio waveform in the concatenated audio waveforms.
    audio_sample_rate: torch.Tensor  # Shape (num_audios,): The sampling rate of the audio waveforms.
    audio_speaker_indices: (
        torch.LongTensor
    )  # Shape (num_audios,): The speaker index for each audio; -1 means unknown speaker.
    audio_label_ids_concat: Optional[torch.LongTensor] = (
        None  # Shape (num_codebooks, audio_seq_len): The concatenated audio label tokens.
    )
    # Here `audio_seq_len` is the length of the concatenated audio tokens.
    reward: Optional[float] = None

    def num_audios(self):
        return max(len(self.audio_waveforms_start), len(self.audio_ids_start))

    def get_audio_codes(self, idx):
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]

        return self.audio_ids_concat[:, code_start:code_end]

    def get_audio_codes_labels(self, idx):
        if self.audio_label_ids_concat is None:
            return None
        code_start = self.audio_ids_start[idx]
        if idx < len(self.audio_ids_start) - 1:
            code_end = self.audio_ids_start[idx + 1]
        else:
            code_end = self.audio_ids_concat.shape[-1]

        return self.audio_label_ids_concat[:, code_start:code_end]

    def get_wv(self, idx):
        wv_start = self.audio_waveforms_start[idx]
        sr = self.audio_sample_rate[idx]
        if idx < len(self.audio_waveforms_start) - 1:
            wv_end = self.audio_waveforms_start[idx + 1]
        else:
            wv_end = self.audio_waveforms_concat.shape[-1]
        return self.audio_waveforms_concat[wv_start:wv_end], sr

    def cal_num_tokens(
        self,
        encode_whisper_embed: bool = True,
        encode_audio_in_tokens: bool = False,
        encode_audio_out_tokens: bool = True,
        audio_in_token_id: int = 128015,
        audio_out_token_id: int = 128016,
    ) -> int:
        # We first exclude <|AUDIO|> and <|AUDIO_OUT|> because we do late merging and replace those positions with the actual audio features and audio token ids.
        # It's assumed that we always have audio_ids when audio_waveforms are there (but not vice-versa).
        num_tokens = len(self.input_ids) - len(self.audio_ids_start)

        if encode_whisper_embed and len(self.audio_waveforms_concat) > 0:
            audio_lengths = torch.diff(self.audio_waveforms_start)
            if len(audio_lengths):
                # Sum before calling .item()
                num_tokens += (
                    (
                        np.ceil(WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC * audio_lengths / self.audio_sample_rate[:-1])
                    ).sum()
                ).item()
            # Add the last audio's token estimation
            num_tokens += (
                np.ceil(
                    WHISPER_EMBED_NUM_HIDDEN_STATE_PER_SEC
                    * (self.audio_waveforms_concat.shape[0] - self.audio_waveforms_start[-1])
                    / self.audio_sample_rate[-1]
                )
            ).item()

        if self.audio_ids_concat.size(1) > 0:
            audio_io_ids = self.input_ids[
                (self.input_ids == audio_in_token_id) | (self.input_ids == audio_out_token_id)
            ]
            audio_io_id_lengths = torch.concat(
                [
                    torch.diff(self.audio_ids_start),
                    torch.tensor([self.audio_ids_concat.shape[-1] - self.audio_ids_start[-1]]),
                ]
            )
            if encode_audio_in_tokens:
                num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_in_token_id]).item()

            if encode_audio_out_tokens:
                num_tokens += torch.sum(audio_io_id_lengths[audio_io_ids == audio_out_token_id]).item()

        return int(num_tokens)

    @classmethod
    def merge(
        cls,
        samples: List["ChatMLDatasetSample"],
        eos_token_id: int,
        ignore_index: int,
        padding_size: Optional[int] = None,
    ) -> "ChatMLDatasetSample":
        """Merge a list of ChatMLDatasetSample instances, inserting eos_token_id and ignore_index between them and adjusting the offsets in audio_ids_start and audio_waveforms_start.

        Args:
            samples (List[ChatMLDatasetSample]): List of samples to merge.
            eos_token_id (int): Token to insert into input_ids between samples.
            ignore_index (int): Default label for padding.
            padding_size (Optional[int]): If provided, append this many padding tokens to the sequence.

        Returns:
            ChatMLDatasetSample: Merged and potentially padded sample.
        """
        if not samples:
            logger.fatal("The samples list is empty and cannot be merged.")
            raise ValueError("The samples list is empty and cannot be merged.")

        # Initialize empty lists for concatenation
        input_ids_list = []
        label_ids_list = []
        audio_ids_concat_list = []
        audio_ids_start_list = []
        audio_waveforms_concat_list = []
        audio_waveforms_start_list = []
        audio_sample_rate_list = []
        audio_speaker_indices_list = []

        # Track offsets
        audio_ids_offset = 0
        audio_waveforms_offset = 0

        for sample in samples:
            # Add input_ids and label_ids with padding
            if input_ids_list:
                input_ids_list.append(torch.tensor([eos_token_id], dtype=torch.long))
                label_ids_list.append(torch.tensor([ignore_index], dtype=torch.long))
            input_ids_list.append(sample.input_ids)
            label_ids_list.append(sample.label_ids)

            # Add audio_ids_concat and handle empty audio ids
            if sample.audio_ids_concat.size(1) > 0:
                audio_ids_concat_list.append(sample.audio_ids_concat)

                # Offset and add audio_ids_start
                audio_ids_start_list.append(sample.audio_ids_start + audio_ids_offset)
                audio_ids_offset += sample.audio_ids_concat.size(
                    1
                )  # (num_codebooks, seq_len): Update offset by audio_seq_len

            # Add audio_waveforms_concat
            if sample.audio_waveforms_concat.size(0) > 0:
                # Check dimensions of the audio waveform to ensure consistency
                if (
                    audio_waveforms_concat_list
                    and sample.audio_waveforms_concat.dim() != audio_waveforms_concat_list[0].dim()
                ):
                    logger.warning(
                        f"Skipping audio waveform with inconsistent dimensions: expected {audio_waveforms_concat_list[0].dim()}D, got {sample.audio_waveforms_concat.dim()}D"
                    )
                    continue

                audio_waveforms_concat_list.append(sample.audio_waveforms_concat)
                audio_waveforms_start_list.append(sample.audio_waveforms_start + audio_waveforms_offset)
                audio_waveforms_offset += sample.audio_waveforms_concat.size(0)

                # Add audio_sample_rate and audio_speaker_indices
                audio_sample_rate_list.append(sample.audio_sample_rate)

            audio_speaker_indices_list.append(sample.audio_speaker_indices)

        # Concatenate all tensors
        input_ids = torch.cat(input_ids_list, dim=0)
        label_ids = torch.cat(label_ids_list, dim=0)

        # Apply padding if padding_size is specified
        if padding_size is not None and padding_size > 0:
            input_ids = torch.cat(
                [
                    input_ids,
                    torch.full((padding_size,), eos_token_id, dtype=torch.long),
                ],
                dim=0,
            )
            label_ids = torch.cat(
                [
                    label_ids,
                    torch.full((padding_size,), ignore_index, dtype=torch.long),
                ],
                dim=0,
            )

        # Safely concatenate audio tensors with proper error handling
        try:
            audio_ids_concat = torch.cat(audio_ids_concat_list, dim=1) if audio_ids_concat_list else torch.tensor([[]])
            audio_ids_start = torch.cat(audio_ids_start_list, dim=0) if audio_ids_start_list else torch.tensor([])

            # Check for dimensional consistency in audio waveforms
            if audio_waveforms_concat_list:
                dims = [t.dim() for t in audio_waveforms_concat_list]
                if not all(d == dims[0] for d in dims):
                    # If dimensions don't match, log a warning and filter out the problematic tensors
                    logger.warning(
                        f"Inconsistent dimensions in audio waveforms: {dims}. Filtering to keep only consistent ones."
                    )
                    expected_dim = max(set(dims), key=dims.count)  # Most common dimension
                    audio_waveforms_concat_list = [t for t in audio_waveforms_concat_list if t.dim() == expected_dim]

                    # Recalculate audio_waveforms_start with the filtered list
                    if audio_waveforms_concat_list:
                        audio_waveforms_offset = 0
                        audio_waveforms_start_list = []
                        for waveform in audio_waveforms_concat_list:
                            audio_waveforms_start_list.append(torch.tensor([audio_waveforms_offset]))
                            audio_waveforms_offset += waveform.size(0)

            audio_waveforms_concat = (
                torch.cat(audio_waveforms_concat_list, dim=0) if audio_waveforms_concat_list else torch.tensor([])
            )
            audio_waveforms_start = (
                torch.cat(audio_waveforms_start_list, dim=0) if audio_waveforms_start_list else torch.tensor([])
            )
            audio_sample_rate = (
                torch.cat(audio_sample_rate_list, dim=0) if audio_sample_rate_list else torch.tensor([])
            )
            audio_speaker_indices = (
                torch.cat(audio_speaker_indices_list, dim=0) if audio_speaker_indices_list else torch.tensor([])
            )

        except RuntimeError as e:
            logger.error(f"Error during tensor concatenation: {str(e)}")
            logger.warning("Falling back to empty audio tensors")
            # Fall back to empty tensors
            audio_ids_concat = torch.tensor([[]])
            audio_ids_start = torch.tensor([])
            audio_waveforms_concat = torch.tensor([])
            audio_waveforms_start = torch.tensor([])
            audio_sample_rate = torch.tensor([])
            audio_speaker_indices = torch.tensor([])

        # Create the merged sample
        merged_sample = cls(
            input_ids=input_ids,
            label_ids=label_ids,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=audio_waveforms_concat,
            audio_waveforms_start=audio_waveforms_start,
            audio_sample_rate=audio_sample_rate,
            audio_speaker_indices=audio_speaker_indices,
        )

        return merged_sample
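A rough sketch of what `merge` produces for two text-only samples (toy tensors; the empty audio fields, the 8-codebook shape, and the eos id 128001 are assumptions for illustration, not values taken from this commit):

empty_audio = dict(
    audio_ids_concat=torch.zeros((8, 0), dtype=torch.long),
    audio_ids_start=torch.tensor([], dtype=torch.long),
    audio_waveforms_concat=torch.tensor([]),
    audio_waveforms_start=torch.tensor([], dtype=torch.long),
    audio_sample_rate=torch.tensor([]),
    audio_speaker_indices=torch.tensor([], dtype=torch.long),
)
a = ChatMLDatasetSample(input_ids=torch.tensor([1, 2]), label_ids=torch.tensor([-100, 2]), **empty_audio)
b = ChatMLDatasetSample(input_ids=torch.tensor([3]), label_ids=torch.tensor([3]), **empty_audio)
merged = ChatMLDatasetSample.merge([a, b], eos_token_id=128001, ignore_index=-100)
# merged.input_ids  -> tensor([1, 2, 128001, 3])
# merged.label_ids  -> tensor([-100, 2, -100, 3])
# merged.cal_num_tokens() -> 4 (no audio placeholders to expand)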
@dataclass
class RankedChatMLDatasetSampleTuple:
    samples: List[ChatMLDatasetSample]
    scores: List[float]

    def max_score_sample(self) -> ChatMLDatasetSample:
        idx = self.scores.index(max(self.scores))
        self.samples[idx].reward = self.scores[idx]
        return self.samples[idx]

    def min_score_sample(self) -> ChatMLDatasetSample:
        idx = self.scores.index(min(self.scores))
        self.samples[idx].reward = self.scores[idx]
        return self.samples[idx]


@dataclass
class ChatMLDatasetStorageSample:
    input_tokens: torch.LongTensor
    label_tokens: torch.LongTensor
    audio_bytes_cache_dir_index: int
    audio_codes_cache_dir_index: int
    audio_bytes_indices: torch.LongTensor
    audio_codes_indices: torch.LongTensor
    speaker_indices: torch.LongTensor
    file_index: int
    original_sample_index: int
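`RankedChatMLDatasetSampleTuple` can then be used to pick preference pairs; a sketch reusing the toy samples above:

ranked = RankedChatMLDatasetSampleTuple(samples=[a, b], scores=[0.9, 0.2])
chosen = ranked.max_score_sample()    # returns `a` and sets a.reward = 0.9
rejected = ranked.min_score_sample()  # returns `b` and sets b.reward = 0.2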
# TODO(sxjscience): We need to revisit the logic for parsing speaker ids.
# Currently, we assume that the speaker id is stored in the "misc" field of ChatMLSample.
def prepare_chatml_sample(sample: Union[ChatMLSample, Dict], tokenizer):
    """Preprocess the ChatML sample to get the tokens for the text part.

    Args:
        sample (ChatMLSample): The ChatML sample to preprocess.
        tokenizer: The tokenizer to use for encoding the text.

    Returns:
        Tuple of (input_tokens, label_tokens, audio_contents, speaker_id), or
        (None, None, None, None) if the sample could not be parsed.
    """

    try:
        if not isinstance(sample, ChatMLSample):
            # Handle all fields that could be NaN
            if "speaker" in sample and pd.isna(sample["speaker"]):
                sample["speaker"] = None
            if "start_index" in sample and pd.isna(sample["start_index"]):
                sample["start_index"] = None
            if "content" in sample and pd.isna(sample["content"]):
                sample["content"] = ""

            # Convert any other potential NaN values in nested structures
            def convert_nan_to_none(obj):
                if isinstance(obj, (pd.Series, np.ndarray)):
                    return obj.tolist()
                elif pd.api.types.is_scalar(obj) and pd.isna(obj):
                    return None
                elif isinstance(obj, dict):
                    return {k: convert_nan_to_none(v) for k, v in obj.items()}
                elif isinstance(obj, (list, tuple)):  # Handle both list and tuple
                    return [convert_nan_to_none(item) for item in obj]
                return obj

            # Clean the sample data
            clean_sample = convert_nan_to_none(sample)

            # Keep only the keys that correspond to ChatMLSample fields
            val_keys = []
            for field in fields(ChatMLSample):
                if field.name in clean_sample:
                    val_keys.append(field.name)
            clean_sample = {k: clean_sample[k] for k in val_keys}

            try:
                sample = dacite.from_dict(
                    data_class=ChatMLSample,
                    data=clean_sample,
                    config=dacite.Config(strict=True, check_types=True),
                )
            except Exception as e:
                print(f"Failed to convert to ChatMLSample: {e}")
                # default=str keeps non-JSON-serializable values from raising inside the error handler
                print(f"Clean sample: {json.dumps(clean_sample, indent=2, default=str)}")
                return None, None, None, None

        input_tokens = []
        label_tokens = []
        audio_contents = []
        speaker_id = None
        if sample.speaker is not None:
            speaker_id = sample.speaker
        elif sample.misc is not None:
            if "speaker" in sample.misc:
                speaker_id = sample.misc["speaker"]

        total_m = len(sample.messages)
        for turn_id, message in enumerate(sample.messages):
            role = message.role
            recipient = message.recipient
            content = message.content
            content_l = []

            if isinstance(content, str):
                content_l.append(TextContent(text=content))
            elif isinstance(content, TextContent):
                content_l.append(content)
            elif isinstance(content, AudioContent):
                content_l.append(content)
            elif isinstance(content, list):
                for ele in content:
                    if isinstance(ele, str):
                        content_l.append(TextContent(text=ele))
                    else:
                        content_l.append(ele)
            if turn_id == 0:
                prefix = f"<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>\n\n"
            else:
                prefix = f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            eot_postfix = "<|eot_id|>"
            eom_postfix = "<|eom_id|>"

            prefix_tokens = tokenizer.encode(prefix, add_special_tokens=False)
            input_tokens.extend(prefix_tokens)
            label_tokens.extend([-100 for _ in prefix_tokens])

            if recipient:
                assert role == "assistant", "Recipient is only available for the assistant role."
                recipient_tokens = tokenizer.encode(f"{recipient}<|recipient|>", add_special_tokens=False)
                input_tokens.extend(recipient_tokens)
                label_tokens.extend(recipient_tokens)

            for content in content_l:
                if content.type == "text":
                    text_tokens = tokenizer.encode(content.text, add_special_tokens=False)
                    input_tokens.extend(text_tokens)
                    if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
                        label_tokens.extend(text_tokens)
                    else:
                        label_tokens.extend([-100 for _ in text_tokens])

                elif content.type == "audio":
                    # Generate the text part of the audio tokens
                    audio_contents.append(content)
                    if role == "user" or role == "system":
                        # Add the text tokens for the audio-in part
                        text_tokens = tokenizer.encode(
                            "<|audio_bos|><|AUDIO|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                        label_tokens.extend([-100 for _ in text_tokens])
                    elif role == "assistant":
                        # Add the text tokens for the audio-out part
                        text_tokens = tokenizer.encode(
                            "<|audio_out_bos|><|AUDIO_OUT|><|audio_eos|>",
                            add_special_tokens=False,
                        )
                        input_tokens.extend(text_tokens)
                        if sample.start_index is None or turn_id >= sample.start_index:
                            label_tokens.extend(text_tokens)
                        else:
                            label_tokens.extend([-100 for _ in text_tokens])
            next_id = turn_id + 1
            if role == "assistant" and next_id != total_m and sample.messages[next_id].role == "assistant":
                postfix_tokens = tokenizer.encode(eom_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
            else:
                postfix_tokens = tokenizer.encode(eot_postfix, add_special_tokens=False)
                input_tokens.extend(postfix_tokens)
            if role == "assistant" and (sample.start_index is None or turn_id >= sample.start_index):
                label_tokens.extend(postfix_tokens)
            else:
                label_tokens.extend([-100 for _ in postfix_tokens])

        return input_tokens, label_tokens, audio_contents, speaker_id

    except Exception as e:
        print(f"Error in prepare_chatml_sample: {str(e)}")
        print(f"Sample data: {json.dumps(sample, indent=2, default=str)}")
        return None, None, None, None
| 469 | 
         
            +
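
# Aside (illustrative, not part of this commit): the -100 fill value used above is
# the default `ignore_index` of PyTorch's cross-entropy loss, so masked positions
# (user/system text, audio placeholders, and turns before `sample.start_index`)
# contribute nothing to the training loss:
#
#     import torch
#     import torch.nn.functional as F
#
#     logits = torch.randn(4, 32)              # 4 positions, vocabulary of 32
#     labels = torch.tensor([5, -100, 7, -100])
#     loss = F.cross_entropy(logits, labels)   # averaged over the 2 unmasked positions only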
def extract_generation_prompt_from_input_tokens(input_tokens, tokenizer):
    """Extract the generation prompt and reference answer from the input tokens.

    For example:

    Input Text = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
    What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>\n\nAt first they went by quick, too quick to even get.<|eot_id|>'

    -->

    Prompt = '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n
    What words do you hear from the provided audio? Write it down for me.<|audio_bos|><|AUDIO|><|audio_eos|><|eot_id|>
    <|start_header_id|>assistant<|end_header_id|>\n\n',
    Reference = 'At first they went by quick, too quick to even get.'

    Args:
        input_tokens: The input tokens.
        tokenizer: The tokenizer to use for decoding the text.

    Returns:
        prompt_tokens: The tokens for the prompt.
        reference_answer: The reference answer.
        num_audios_in_reference: The number of audio segments in the reference answer.
    """
    input_text = tokenizer.decode(input_tokens)
    generation_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
    postfix = "<|eot_id|>"
    assert generation_prefix in input_text
    generation_prompt_end_loc = input_text.rfind(generation_prefix) + len(generation_prefix)
    generation_prompt = input_text[:generation_prompt_end_loc]
    reference_answer = input_text[generation_prompt_end_loc : input_text.find(postfix, generation_prompt_end_loc)]
    num_audios_in_reference = reference_answer.count(AUDIO_IN_TOKEN) + reference_answer.count(AUDIO_OUT_TOKEN)
    return (
        tokenizer.encode(generation_prompt, add_special_tokens=False),
        reference_answer,
        num_audios_in_reference,
    )
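
# Usage sketch (illustrative, not part of this commit; assumes `input_tokens` came
# from `prepare_chatml_sample` with the Llama-3.1 tokenizer used in this file):
#
#     prompt_tokens, reference, n_ref_audios = extract_generation_prompt_from_input_tokens(
#         input_tokens, tokenizer
#     )
#     # `prompt_tokens` re-encodes everything up to and including the last
#     # "<|start_header_id|>assistant<|end_header_id|>\n\n"; `reference` is the text
#     # before the following "<|eot_id|>"; `n_ref_audios` counts the <|AUDIO|> /
#     # <|AUDIO_OUT|> placeholders inside the reference.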
def prepare_chatml_dataframe_single_process(df, tokenizer):
    """Prepare the ChatML DataFrame in a single process."""
    ret = []
    for _, row in df.iterrows():
        input_tokens, label_tokens, audio_contents, speaker_id = prepare_chatml_sample(row.to_dict(), tokenizer)
        ret.append((input_tokens, label_tokens, audio_contents, speaker_id))
    return ret


def prepare_chatml_dataframe(df, tokenizer, num_process=16):
    """Prepare the ChatML DataFrame, optionally fanning the work out across processes."""
    if num_process is None:
        return prepare_chatml_dataframe_single_process(df, tokenizer)
    else:
        num_process = max(min(len(df) // 1000, num_process), 1)
        workloads = np.array_split(df, num_process)
        with mp.Pool(num_process) as pool:
            ret = pool.starmap(
                prepare_chatml_dataframe_single_process,
                [(workload, tokenizer) for workload in workloads],
            )
    return sum(ret, [])
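
# Worked example of the worker-count heuristic above (illustrative, not part of
# this commit): `max(min(len(df) // 1000, num_process), 1)` allocates roughly one
# worker per 1000 rows, capped at the requested `num_process` and floored at 1:
#
#     len(df) = 500      ->  1 worker (the whole frame in one split)
#     len(df) = 8_000    ->  8 workers
#     len(df) = 100_000  -> 16 workers (with the default num_process=16)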
class DatasetInterface(ABC):
    @abstractmethod
    def __getitem__(self, idx) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Retrieve a dataset sample by index."""
        raise NotImplementedError


class IterableDatasetInterface(ABC):
    @abstractmethod
    def __iter__(self) -> Union["ChatMLDatasetSample", "RankedChatMLDatasetSampleTuple"]:
        """Retrieve a sample by iterating through the dataset."""
        raise NotImplementedError


@dataclass
class DatasetInfo:
    dataset_type: str
    group_type: Optional[str] = None
    mask_text: Optional[bool] = None  # Whether to mask the text tokens for pretraining samples.
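`DatasetInfo` above is a plain record describing a dataset entry. A minimal construction sketch (field values are hypothetical, not from this commit):

    info = DatasetInfo(dataset_type="chatml", mask_text=False)
    assert info.group_type is None  # the optional grouping key defaults to None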
    	
higgs_audio/model/__init__.py
ADDED
@@ -0,0 +1,9 @@

from transformers import AutoConfig, AutoModel

from .configuration_higgs_audio import HiggsAudioConfig, HiggsAudioEncoderConfig
from .modeling_higgs_audio import HiggsAudioModel


AutoConfig.register("higgs_audio_encoder", HiggsAudioEncoderConfig)
AutoConfig.register("higgs_audio", HiggsAudioConfig)
AutoModel.register(HiggsAudioConfig, HiggsAudioModel)
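The three `register` calls above hook Higgs-Audio into the `transformers` Auto-class machinery: any config whose `model_type` is `"higgs_audio"` now resolves to `HiggsAudioModel`. A minimal sketch of what this enables (the model below is randomly initialized; illustrative, not from this commit):

    from transformers import AutoModel

    from higgs_audio.model import HiggsAudioConfig  # importing the package runs the register() calls

    config = HiggsAudioConfig()            # default Llama text backbone + default audio encoder
    model = AutoModel.from_config(config)  # dispatched to HiggsAudioModel through the registry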
    	
higgs_audio/model/audio_head.py
ADDED
@@ -0,0 +1,139 @@

"""Projector that maps hidden states from the LLM component to multimodal logits."""

import torch
from torch import nn

from dataclasses import dataclass
from typing import Optional, Tuple

from .common import HiggsAudioPreTrainedModel
from .configuration_higgs_audio import HiggsAudioConfig


@dataclass
class HiggsAudioDecoderLayerOutput:
    logits: torch.FloatTensor
    audio_logits: torch.FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
    """Projection layers that map hidden states from the LLM component to audio / text logits.

    We support two types of audio head:
    - Basic Audio Head:
        Directly map the hidden states to audio logits for all the codebooks.
    """

    def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
        super().__init__(config)
        self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.audio_lm_head = nn.Linear(
            config.text_config.hidden_size,
            config.audio_num_codebooks * (config.audio_codebook_size + 2),
            bias=False,
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        hidden_states,
        audio_out_mask,
        label_audio_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        output_audio_hidden_states=False,
        cache_position=None,
    ):
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`):
                Hidden states from the LLM component.
            audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Mask for identifying the audio-out tokens.
            label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`):
                Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used.
            attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Mask to avoid performing attention on padding token indices.
            position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`):
                Position ids for the input tokens.

        Returns:
            logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`):
                Logits for text tokens.
            audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * (audio_codebook_size + 2))`):
                Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len`.
        """
        logits = self.text_lm_head(hidden_states)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
        if self.config.audio_decoder_proj_num_layers > 0:
            # Create position embeddings to be shared across the decoder layers.
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
            for decoder_layer in self.transformer_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_ids,
                        past_key_values,
                        output_attentions,
                        use_cache,
                        cache_position,
                        position_embeddings,
                    )
                else:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )
                hidden_states = layer_outputs[0]
            hidden_states = self.norm(hidden_states)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        next_cache = next_decoder_cache if use_cache else None

        audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])

        if output_audio_hidden_states:
            audio_hidden_states = hidden_states[audio_out_mask]
        else:
            audio_hidden_states = None

        return (
            logits,
            audio_logits,
            all_self_attns,
            all_hidden_states,
            audio_hidden_states,
            next_cache,
        )
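Two shape details in `HiggsAudioDecoderProjector.forward` are easy to miss: indexing with the boolean `audio_out_mask` flattens the batch and sequence dimensions, so `audio_logits` has one row per audio-out position across the whole batch; and the audio head width is `audio_num_codebooks * (audio_codebook_size + 2)`, where the `+ 2` leaves room for the per-codebook stream BOS/EOS ids (1024 and 1025 in the default configuration below). A standalone shape check under those defaults (illustrative, not from this commit):

    import torch
    from torch import nn

    num_codebooks, codebook_size, hidden_size = 12, 1024, 4096
    audio_lm_head = nn.Linear(hidden_size, num_codebooks * (codebook_size + 2), bias=False)

    hidden_states = torch.randn(2, 5, hidden_size)        # (batch, seq, hidden)
    audio_out_mask = torch.zeros(2, 5, dtype=torch.bool)
    audio_out_mask[0, 3] = True                           # two audio-out positions in total
    audio_out_mask[1, 1] = True

    audio_logits = audio_lm_head(hidden_states[audio_out_mask])             # shape (2, 12 * 1026)
    per_codebook = audio_logits.view(-1, num_codebooks, codebook_size + 2)  # shape (2, 12, 1026)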
    	
higgs_audio/model/common.py
ADDED
@@ -0,0 +1,27 @@

from torch import nn

from transformers.modeling_utils import PreTrainedModel

from .configuration_higgs_audio import HiggsAudioConfig


class HiggsAudioPreTrainedModel(PreTrainedModel):
    config_class = HiggsAudioConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = []
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def _init_weights(self, module):
        std = self.config.init_std if hasattr(self.config, "init_std") else self.config.audio_encoder_config.init_std

        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
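`_init_weights` is not called directly: in the `transformers` convention, the `post_init()` call at the end of each subclass constructor (as in `HiggsAudioDecoderProjector` above) triggers weight initialization, which applies this rule to every `nn.Linear`, `nn.Conv1d`, and `nn.Embedding` submodule. A standalone sketch of the same rule (illustrative, not from this commit):

    from torch import nn

    def init_module(module: nn.Module, std: float = 0.02) -> None:
        # Mirrors HiggsAudioPreTrainedModel._init_weights with init_std=0.02.
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
    model.apply(init_module)  # apply() recurses over all submodules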
    	
        higgs_audio/model/configuration_higgs_audio.py
    ADDED
    
    | 
         @@ -0,0 +1,235 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from transformers.configuration_utils import PretrainedConfig
         
     | 
| 2 | 
         
            +
            from transformers.models.auto import CONFIG_MAPPING
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            class HiggsAudioEncoderConfig(PretrainedConfig):
         
     | 
| 6 | 
         
            +
                """Configuration of the Audio encoder in Higgs-Audio."""
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
                model_type = "higgs_audio_encoder"
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
                def __init__(
         
     | 
| 11 | 
         
            +
                    self,
         
     | 
| 12 | 
         
            +
                    num_mel_bins=128,
         
     | 
| 13 | 
         
            +
                    encoder_layers=32,
         
     | 
| 14 | 
         
            +
                    encoder_attention_heads=20,
         
     | 
| 15 | 
         
            +
                    encoder_ffn_dim=5120,
         
     | 
| 16 | 
         
            +
                    encoder_layerdrop=0.0,
         
     | 
| 17 | 
         
            +
                    d_model=1280,
         
     | 
| 18 | 
         
            +
                    dropout=0.0,
         
     | 
| 19 | 
         
            +
                    attention_dropout=0.0,
         
     | 
| 20 | 
         
            +
                    activation_function="gelu",
         
     | 
| 21 | 
         
            +
                    activation_dropout=0.0,
         
     | 
| 22 | 
         
            +
                    scale_embedding=False,
         
     | 
| 23 | 
         
            +
                    init_std=0.02,
         
     | 
| 24 | 
         
            +
                    max_source_positions=1500,
         
     | 
| 25 | 
         
            +
                    pad_token_id=128001,
         
     | 
| 26 | 
         
            +
                    **kwargs,
         
     | 
| 27 | 
         
            +
                ):
         
     | 
| 28 | 
         
            +
                    super().__init__(**kwargs)
         
     | 
| 29 | 
         
            +
             
     | 
| 30 | 
         
            +
                    self.num_mel_bins = num_mel_bins
         
     | 
| 31 | 
         
            +
                    self.d_model = d_model
         
     | 
| 32 | 
         
            +
                    self.encoder_layers = encoder_layers
         
     | 
| 33 | 
         
            +
                    self.encoder_attention_heads = encoder_attention_heads
         
     | 
| 34 | 
         
            +
                    self.encoder_ffn_dim = encoder_ffn_dim
         
     | 
| 35 | 
         
            +
                    self.dropout = dropout
         
     | 
| 36 | 
         
            +
                    self.attention_dropout = attention_dropout
         
     | 
| 37 | 
         
            +
                    self.activation_function = activation_function
         
     | 
| 38 | 
         
            +
                    self.activation_dropout = activation_dropout
         
     | 
| 39 | 
         
            +
                    self.encoder_layerdrop = encoder_layerdrop
         
     | 
| 40 | 
         
            +
                    self.num_hidden_layers = encoder_layers
         
     | 
| 41 | 
         
            +
                    self.init_std = init_std
         
     | 
| 42 | 
         
            +
                    self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
         
     | 
| 43 | 
         
            +
                    self.max_source_positions = max_source_positions
         
     | 
| 44 | 
         
            +
                    self.pad_token_id = pad_token_id
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
            class HiggsAudioConfig(PretrainedConfig):
         
     | 
| 48 | 
         
            +
                r"""
         
     | 
| 49 | 
         
            +
                This is the configuration class for the HiggsAudioModel.
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                Args:
         
     | 
| 52 | 
         
            +
                    text_config (`Union[AutoConfig, dict]`):
         
     | 
| 53 | 
         
            +
                        The config object or dictionary of the text backbone.
         
     | 
| 54 | 
         
            +
                    audio_encoder_config (`Union[AutoConfig, dict]`):
         
     | 
| 55 | 
         
            +
                        The config object or dictionary of the whisper encoder.
         
     | 
| 56 | 
         
            +
                        The audio encoder will be bidirectional and will be only available for audio understanding.
         
     | 
| 57 | 
         
            +
                    audio_tokenizer_config
         
     | 
| 58 | 
         
            +
                        The config object or dictionary of the audio tokenizer.
         
     | 
| 59 | 
         
            +
                    audio_adapter_type
         
     | 
| 60 | 
         
            +
                        The type of audio adapter to use. We support two types of adapter:
         
     | 
| 61 | 
         
            +
                        - stack:
         
     | 
| 62 | 
         
            +
                            We stack additional Transformer layers after the main LLM backbone for audio generation.
         
     | 
| 63 | 
         
            +
                        - dual_ffn:
         
     | 
| 64 | 
         
            +
                            For selected part of the LLM backbone, we replace the text FFN with a dual FFN architecture
         
     | 
| 65 | 
         
            +
                            that contains an additional audio FFN. The audio FFN will be triggered when the location is marked for audio tokens.
         
     | 
| 66 | 
         
            +
                        - dual_ffn_fast_forward:
         
     | 
| 67 | 
         
            +
                            We pick a few layers in the LLM backbone to plug-in the audio FFN. For the remaining layers,
         
     | 
| 68 | 
         
            +
                            the audio hidden states will be directly fast-forward to the next layer.
         
     | 
| 69 | 
         
            +
                            This reduces the computational cost for audio generation.
         
     | 
| 70 | 
         
            +
                    audio_embed_avg (`bool`, *optional*, defaults to False):
         
     | 
| 71 | 
         
            +
                        Whether to average the audio embeddings before sending them to the text attention layer.
         
     | 
| 72 | 
         
            +
                    audio_ffn_hidden_size
         
     | 
| 73 | 
         
            +
                        The hidden size of the audio feedforward network in dual-path FFN
         
     | 
| 74 | 
         
            +
                    audio_ffn_intermediate_size
         
     | 
| 75 | 
         
            +
                        The intermediate size of the audio feedforward network in dual-path FFN
         
     | 
| 76 | 
         
            +
                    audio_dual_ffn_layers
         
     | 
| 77 | 
         
            +
                        The layers in the LLM backbone to plug-in the dual FFN layer (mixture of audio FFN and text FFN).
         
     | 
| 78 | 
         
            +
                    audio_decoder_proj_num_attention (`int`, *optional*, defaults to 0):
         
     | 
| 79 | 
         
            +
                        The number of attention heads in the audio decoder projection layer.
         
     | 
| 80 | 
         
            +
                    use_delay_pattern (`bool`, *optional*, defaults to False):
         
     | 
| 81 | 
         
            +
                        Whether to use delay pattern in the audio decoder.
         
     | 
| 82 | 
         
            +
                    skip_audio_tower (`bool`, *optional*, defaults to False):
         
     | 
| 83 | 
         
            +
                        Whether to skip the audio tower in the audio encoder.
         
     | 
| 84 | 
         
            +
                    use_audio_out_embed_projector (`bool`, *optional*, defaults to False):
         
     | 
| 85 | 
         
            +
                        Whether to use an embedding projector to map audio out embeddings.
         
     | 
| 86 | 
         
            +
                    use_audio_out_self_attention (`bool`, *optional*, defaults to False):
         
     | 
| 87 | 
         
            +
                        Whether to use self-attention to aggregate information from audio-tokens before sending to the text attention layer.
         
     | 
| 88 | 
         
            +
                    audio_num_codebooks (`int`, *optional*, defaults to 12):
         
     | 
| 89 | 
         
            +
                        The number of codebooks in RVQGAN.
         
     | 
| 90 | 
         
            +
                    audio_codebook_size (`int`, *optional*, defaults to 1024):
         
     | 
| 91 | 
         
            +
                        The size of each codebook in RVQGAN.
         
     | 
| 92 | 
         
            +
                    audio_stream_bos_id
         
     | 
| 93 | 
         
            +
                        The id of the bos in the audio stream
         
     | 
| 94 | 
         
            +
                    audio_stream_eos_id
         
     | 
| 95 | 
         
            +
                        The id of the eos in the audio stream
         
     | 
| 96 | 
         
            +
                    audio_bos_token (`str`, *optional*, defaults to "<|audio_bos|>"):
         
     | 
| 97 | 
         
            +
                        The special `<|audio_bos|>` token. In Higgs-Audio, it is mapped to 128011,
         
     | 
| 98 | 
         
            +
                        which is the index of `<|reserved_special_token_3|>` in Llama-3.1-8B-Instruct's tokenizer.
         
     | 
| 99 | 
         
            +
                    audio_eos_token (`str`, *optional*, defaults to "<|audio_eos|>"):
         
     | 
| 100 | 
         
            +
                        The special `<|audio_eos|>` token. We use 128012 as the default value,
         
     | 
| 101 | 
         
            +
                        which is the index of `<|reserved_special_token_4|>` in Llama-3.1-8B-Instruct's tokenizer.
         
     | 
| 102 | 
         
            +
                    audio_out_bos_token (`str`, *optional*, defaults to "<|audio_out_bos|>"):
         
     | 
| 103 | 
         
            +
                        The special `<|audio_out_bos|>` token. We use 128013 as the default value,
         
     | 
| 104 | 
         
            +
                        which is the index of `<|reserved_special_token_5|>` in Llama-3.1-8B-Instruct's tokenizer.
         
     | 
| 105 | 
         
            +
                    audio_token (`str`, *optional*, defaults to "<|AUDIO|>"):
         
     | 
| 106 | 
         
            +
                        The special `<|AUDIO|>` token. We use 128015 as the default value,
         
     | 
| 107 | 
         
            +
                        which is the index of `<|reserved_special_token_7|>` in Llama-3.1-8B-Instruct's tokenizer.
         
     | 
| 108 | 
         
            +
                        This token indicates that the location should be filled in with whisper features.
         
     | 
| 109 | 
         
            +
                    audio_out_token (`str`, *optional*, defaults to "<|AUDIO_OUT|>"):
         
     | 
| 110 | 
         
            +
                        The special `<|AUDIO_OUT|>` token. We use 128016 as the default value,
         
     | 
| 111 | 
         
            +
                        which is the index of `<|reserved_special_token_8|>` in Llama-3.1-8B-Instruct's tokenizer.
         
     | 
| 112 | 
         
            +
                        This token indicates that the location should be filled in with audio tokens extracted via audio tokenizer.
         
     | 
| 113 | 
         
            +
                """
         
     | 
| 114 | 
         
            +
             
     | 
| 115 | 
         
            +
                model_type = "higgs_audio"
         
     | 
| 116 | 
         
            +
                is_composition = True
         
     | 
| 117 | 
         
            +
             
     | 
| 118 | 
         
            +
                def __init__(
         
     | 
| 119 | 
         
            +
                    self,
         
     | 
| 120 | 
         
            +
                    text_config=None,
         
     | 
| 121 | 
         
            +
                    audio_encoder_config=None,
         
     | 
| 122 | 
         
            +
                    audio_tokenizer_config=None,
         
     | 
| 123 | 
         
            +
                    audio_adapter_type="stack",
         
     | 
| 124 | 
         
            +
                    audio_embed_avg=False,
         
     | 
| 125 | 
         
            +
                    audio_ffn_hidden_size=4096,
         
     | 
| 126 | 
         
            +
                    audio_ffn_intermediate_size=14336,
         
     | 
| 127 | 
         
            +
                    audio_dual_ffn_layers=None,
         
     | 
| 128 | 
         
            +
                    audio_decoder_proj_num_layers=0,
         
     | 
| 129 | 
         
            +
                    encode_whisper_embed=True,
         
     | 
| 130 | 
         
            +
                    encode_audio_in_tokens=False,
         
     | 
| 131 | 
         
            +
                    use_delay_pattern=False,
         
     | 
| 132 | 
         
            +
                    skip_audio_tower=False,
         
     | 
| 133 | 
         
            +
                    use_audio_out_embed_projector=False,
         
     | 
| 134 | 
         
            +
                    use_audio_out_self_attention=False,
         
     | 
| 135 | 
         
            +
                    use_rq_transformer=False,
         
     | 
| 136 | 
         
            +
                    rq_transformer_hidden_size=None,
         
     | 
| 137 | 
         
            +
                    rq_transformer_intermediate_size=None,
         
     | 
| 138 | 
         
            +
                    rq_transformer_num_attention_heads=None,
         
     | 
| 139 | 
         
            +
                    rq_transformer_num_key_value_heads=None,
         
     | 
| 140 | 
         
            +
                    rq_transformer_num_hidden_layers=3,
         
     | 
| 141 | 
         
            +
                    audio_num_codebooks=12,
         
     | 
| 142 | 
         
            +
                    audio_codebook_size=1024,
         
     | 
| 143 | 
         
            +
                    audio_stream_bos_id=1024,
         
     | 
| 144 | 
         
            +
                    audio_stream_eos_id=1025,
         
     | 
| 145 | 
         
            +
                    audio_bos_token="<|audio_bos|>",
         
     | 
| 146 | 
         
            +
                    audio_eos_token="<|audio_eos|>",
         
     | 
| 147 | 
         
            +
                    audio_out_bos_token="<|audio_out_bos|>",
         
     | 
| 148 | 
         
            +
                    audio_in_token="<|AUDIO|>",
         
     | 
| 149 | 
         
            +
                    audio_out_token="<|AUDIO_OUT|>",
         
     | 
| 150 | 
         
            +
                    audio_in_token_idx=128015,
         
     | 
| 151 | 
         
            +
                    audio_out_token_idx=128016,
         
     | 
| 152 | 
         
            +
                    pad_token_id=128001,
         
     | 
| 153 | 
         
            +
                    audio_out_bos_token_id=128013,
         
     | 
| 154 | 
         
            +
                    audio_eos_token_id=128012,
         
     | 
| 155 | 
         
            +
                    **kwargs,
         
     | 
| 156 | 
         
            +
                ):
         
     | 
| 157 | 
         
            +
                    if isinstance(audio_encoder_config, dict):
         
     | 
| 158 | 
         
            +
                        audio_encoder_config["model_type"] = (
         
     | 
| 159 | 
         
            +
                            audio_encoder_config["model_type"] if "model_type" in audio_encoder_config else "higgs_audio_encoder"
         
     | 
| 160 | 
         
            +
                        )
         
     | 
| 161 | 
         
            +
                        audio_encoder_config = CONFIG_MAPPING[audio_encoder_config["model_type"]](**audio_encoder_config)
         
     | 
| 162 | 
         
            +
                    elif audio_encoder_config is None:
         
+            audio_encoder_config = HiggsAudioEncoderConfig()
+
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
+            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+        elif text_config is None:
+            text_config = CONFIG_MAPPING["llama"]()
+
+        assert audio_adapter_type in [
+            "stack",
+            "dual_ffn",
+            "dual_ffn_fast_forward",
+        ], f"Invalid audio adapter type: {audio_adapter_type}"
+        if audio_adapter_type.startswith("dual_ffn"):
+            assert audio_dual_ffn_layers is not None, (
+                "audio_dual_ffn_layers must be specified when using dual_ffn adapter."
+            )
+        self.text_config = text_config
+        self.audio_encoder_config = audio_encoder_config
+        self.audio_tokenizer_config = audio_tokenizer_config
+        self.audio_adapter_type = audio_adapter_type
+        self.audio_embed_avg = audio_embed_avg
+        self.audio_ffn_hidden_size = audio_ffn_hidden_size
+        self.audio_ffn_intermediate_size = audio_ffn_intermediate_size
+        self.audio_dual_ffn_layers = audio_dual_ffn_layers
+        self.audio_decoder_proj_num_layers = audio_decoder_proj_num_layers
+        self.encode_whisper_embed = encode_whisper_embed
+        self.encode_audio_in_tokens = encode_audio_in_tokens
+        self.use_delay_pattern = use_delay_pattern
+        self.skip_audio_tower = skip_audio_tower
+        self.use_audio_out_embed_projector = use_audio_out_embed_projector
+        self.use_audio_out_self_attention = use_audio_out_self_attention
+
+        self.use_rq_transformer = use_rq_transformer
+
+        if self.use_rq_transformer:
+            assert not self.use_delay_pattern, "Delay pattern is not supported if you turned on RQ-Transformer!"
+        self.rq_transformer_hidden_size = rq_transformer_hidden_size
+        self.rq_transformer_intermediate_size = rq_transformer_intermediate_size
+        self.rq_transformer_num_attention_heads = rq_transformer_num_attention_heads
+        self.rq_transformer_num_key_value_heads = rq_transformer_num_key_value_heads
+        self.rq_transformer_num_hidden_layers = rq_transformer_num_hidden_layers
+
+        if use_rq_transformer:
+            # For RQ-Transformer, we set the hidden_size to the same as the text model's hidden size if it is not specified.
+            if self.rq_transformer_hidden_size is None:
+                self.rq_transformer_hidden_size = text_config.hidden_size
+            assert self.rq_transformer_hidden_size % 128 == 0
+            if self.rq_transformer_intermediate_size is None:
+                self.rq_transformer_intermediate_size = text_config.intermediate_size
+            if self.rq_transformer_num_attention_heads is None:
+                self.rq_transformer_num_attention_heads = self.rq_transformer_hidden_size // 128
+            if self.rq_transformer_num_key_value_heads is None:
+                self.rq_transformer_num_key_value_heads = self.rq_transformer_hidden_size // 128 // 4
+            assert self.rq_transformer_hidden_size % self.rq_transformer_num_attention_heads == 0
+            assert self.rq_transformer_hidden_size % self.rq_transformer_num_key_value_heads == 0
+
+        self.audio_num_codebooks = audio_num_codebooks
+        self.audio_codebook_size = audio_codebook_size
+        self.audio_bos_token = audio_bos_token
+        self.audio_eos_token = audio_eos_token
+        self.audio_out_bos_token = audio_out_bos_token
+        self.audio_in_token = audio_in_token
+        self.audio_out_token = audio_out_token
+        self.audio_in_token_idx = audio_in_token_idx
+        self.audio_out_token_idx = audio_out_token_idx
+        self.audio_stream_bos_id = audio_stream_bos_id
+        self.audio_stream_eos_id = audio_stream_eos_id
+        self.audio_out_bos_token_id = audio_out_bos_token_id
+        self.audio_eos_token_id = audio_eos_token_id
+
+        super().__init__(**kwargs)
+        self.pad_token_id = pad_token_id
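For orientation (not part of the commit), a minimal sketch of how the validation and RQ-Transformer defaults above resolve. HiggsAudioConfig is the class defined in this file; the argument values and layer indices below are hypothetical.

    from higgs_audio.model.configuration_higgs_audio import HiggsAudioConfig

    # text_config=None falls back to a default Llama config. With a 4096-wide
    # text model, the defaults above resolve to 4096 // 128 = 32 attention
    # heads and 4096 // 128 // 4 = 8 key-value heads for the RQ-Transformer.
    config = HiggsAudioConfig(
        audio_adapter_type="dual_ffn",
        audio_dual_ffn_layers=[0, 2, 4],  # hypothetical layer indices
        use_rq_transformer=True,
    )
    assert config.text_config.model_type == "llama"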
    	
higgs_audio/model/cuda_graph_runner.py
ADDED
@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+from typing import Optional, List, Dict, Tuple, Union
+import gc
+
+from transformers.cache_utils import Cache
+
+
+_NUM_WARMUP_ITERS = 2
+
+
+class CUDAGraphRunner(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+        self.input_buffers: Dict[str, torch.Tensor] = {}
+        self.output_buffers: Dict[str, torch.Tensor] = {}
+
+        self._graph: Optional[torch.cuda.CUDAGraph] = None
+
+    @property
+    def graph(self):
+        assert self._graph is not None
+        return self._graph
+
+    def capture(
+        self,
+        hidden_states: torch.Tensor,
+        causal_mask: torch.Tensor,
+        position_ids: torch.Tensor,
+        audio_discrete_codes_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        past_key_values: Union[Cache, List[torch.FloatTensor]],
+        use_cache: bool,
+        audio_attention_mask: torch.Tensor,
+        fast_forward_attention_mask: torch.Tensor,
+        output_attentions: bool,
+        output_hidden_states: bool,
+        is_decoding_audio_token: Optional[bool] = None,
+        is_using_cuda_graph: Optional[bool] = False,
+        stream: torch.cuda.Stream = None,
+        memory_pool: Optional[Tuple[int, int]] = None,
+    ):
+        assert self._graph is None
+        # Run warmup iterations
+        for _ in range(_NUM_WARMUP_ITERS):
+            self.model(
+                hidden_states=hidden_states,
+                causal_mask=causal_mask,
+                position_ids=position_ids,
+                audio_discrete_codes_mask=audio_discrete_codes_mask,
+                cache_position=cache_position,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                audio_attention_mask=audio_attention_mask,
+                fast_forward_attention_mask=fast_forward_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                is_decoding_audio_token=is_decoding_audio_token,
+                is_using_cuda_graph=is_using_cuda_graph,
+            )
+
+        torch.cuda.synchronize()
+
+        # Capture the graph
+        self._graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
+            out_hidden_states, all_hidden_states, all_self_attns = self.model(
+                hidden_states=hidden_states,
+                causal_mask=causal_mask,
+                position_ids=position_ids,
+                audio_discrete_codes_mask=audio_discrete_codes_mask,
+                cache_position=cache_position,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                audio_attention_mask=audio_attention_mask,
+                fast_forward_attention_mask=fast_forward_attention_mask,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                is_decoding_audio_token=is_decoding_audio_token,
+                is_using_cuda_graph=is_using_cuda_graph,
+            )
+            # hidden_states_out = torch.ops._C.weak_ref_tensor(outputs[0])
+            # del outputs
+            gc.collect()
+        torch.cuda.synchronize()
+
+        # Save input and output buffers
+        self.input_buffers = {
+            "hidden_states": hidden_states,
+            "causal_mask": causal_mask,
+            "position_ids": position_ids,
+            "audio_discrete_codes_mask": audio_discrete_codes_mask,
+            "cache_position": cache_position,
+            "past_key_values": past_key_values,
+            "audio_attention_mask": audio_attention_mask,
+            "fast_forward_attention_mask": fast_forward_attention_mask,
+        }
+        self.output_buffers = {
+            "hidden_states": out_hidden_states,
+            "all_hidden_states": all_hidden_states,
+            "all_self_attns": all_self_attns,
+        }
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        causal_mask: torch.Tensor,
+        position_ids: torch.Tensor,
+        audio_discrete_codes_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        audio_attention_mask: torch.Tensor,
+        fast_forward_attention_mask: torch.Tensor,
+        **kwargs,
+    ) -> torch.Tensor:
+        # Copy input tensors to buffers
+        self.input_buffers["hidden_states"].copy_(hidden_states, non_blocking=True)
+        self.input_buffers["causal_mask"].copy_(causal_mask, non_blocking=True)
+        self.input_buffers["position_ids"].copy_(position_ids, non_blocking=True)
+        self.input_buffers["audio_discrete_codes_mask"].copy_(audio_discrete_codes_mask, non_blocking=True)
+        self.input_buffers["cache_position"].copy_(cache_position, non_blocking=True)
+        self.input_buffers["audio_attention_mask"].copy_(audio_attention_mask, non_blocking=True)
+        self.input_buffers["fast_forward_attention_mask"].copy_(fast_forward_attention_mask, non_blocking=True)
+
+        # Run the captured graph
+        self.graph.replay()
+
+        return self.output_buffers["hidden_states"], None, None
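A hedged usage sketch (not part of the commit): the runner is captured once against preallocated static buffers, then replayed on every decode step. All tensor and module names below are hypothetical placeholders.

    runner = CUDAGraphRunner(decoder_core)  # decoder_core: the module being wrapped
    runner.capture(
        hidden_states=hs, causal_mask=mask, position_ids=pos,
        audio_discrete_codes_mask=codes_mask, cache_position=cache_pos,
        past_key_values=kv_cache, use_cache=True,
        audio_attention_mask=audio_mask, fast_forward_attention_mask=ff_mask,
        output_attentions=False, output_hidden_states=False,
    )
    # Later steps copy fresh inputs into the captured buffers and replay:
    out, _, _ = runner(hs, mask, pos, codes_mask, cache_pos, audio_mask, ff_mask)

Note that `forward` never re-enters the Python model code; replay is only valid while the KV cache and the other captured buffers keep their original addresses and shapes.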
    	
higgs_audio/model/custom_modules.py
ADDED
@@ -0,0 +1,155 @@
+import torch
+import torch.nn as nn
+
+
+class PartiallyFrozenEmbedding(nn.Module):
+    """Wrap an existing `nn.Embedding` module and split it into:
+
+    - A frozen embedding for indices [0..freeze_until_idx-1].
+    - A trainable embedding for indices [freeze_until_idx..vocab_size-1].
+
+    This should work with both Zero-2 and Zero-3 seamlessly.
+    """
+
+    def __init__(self, original_embedding: nn.Embedding, freeze_until_idx: int):
+        """
+        :param original_embedding: An instance of nn.Embedding (the original embedding layer).
+        :param freeze_until_idx: The index up to which the embedding is frozen (exclusive); the row at freeze_until_idx stays trainable.
+        """
+        super().__init__()
+        self.freeze_until_idx = freeze_until_idx
+        self.original_vocab_size = original_embedding.num_embeddings
+        self.embedding_dim = original_embedding.embedding_dim
+
+        # Split the original embedding into frozen and trainable parts
+        self.embedding_frozen = nn.Embedding(
+            freeze_until_idx,
+            self.embedding_dim,
+            dtype=original_embedding.weight.dtype,
+            device=original_embedding.weight.device,
+        )
+        self.embedding_trainable = nn.Embedding(
+            self.original_vocab_size - freeze_until_idx,
+            self.embedding_dim,
+            dtype=original_embedding.weight.dtype,
+            device=original_embedding.weight.device,
+        )
+
+        # Copy weights from the original embedding into the frozen and trainable parts
+        with torch.no_grad():
+            self.embedding_frozen.weight.copy_(original_embedding.weight[:freeze_until_idx])
+            self.embedding_trainable.weight.copy_(original_embedding.weight[freeze_until_idx:])
+
+        # Freeze the frozen embedding
+        self.embedding_frozen.weight.requires_grad = False
+
+    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass for the split embedding wrapper.
+        :param input_ids: Tensor of shape [batch_size, seq_len] with indices in [0..original_vocab_size-1].
+        """
+        # Masks to separate frozen and trainable indices
+        # (bsz, seq_len)
+        mask_frozen = input_ids < self.freeze_until_idx
+        mask_trainable = ~mask_frozen
+
+        # Output tensor for embedding results
+        batch_size, seq_len = input_ids.shape
+        embeddings = torch.zeros(
+            batch_size,
+            seq_len,
+            self.embedding_dim,
+            device=input_ids.device,
+            dtype=self.embedding_frozen.weight.dtype,
+        )
+
+        # Handle frozen embedding
+        if mask_frozen.any():
+            frozen_ids = input_ids[mask_frozen]
+            frozen_emb = self.embedding_frozen(frozen_ids)
+            embeddings[mask_frozen] = frozen_emb
+
+        # Handle trainable embedding
+        if mask_trainable.any():
+            # Adjust trainable IDs to the local index space of the trainable embedding
+            trainable_ids = input_ids[mask_trainable] - self.freeze_until_idx
+            trainable_emb = self.embedding_trainable(trainable_ids)
+            embeddings[mask_trainable] = trainable_emb
+
+        return embeddings
+
+    def to_unsplit(self) -> nn.Embedding:
+        unsplit_embedding = nn.Embedding(
+            self.original_vocab_size,
+            self.embedding_dim,
+            dtype=self.embedding_frozen.weight.dtype,
+            device=self.embedding_frozen.weight.device,
+        )
+
+        with torch.no_grad():
+            unsplit_embedding.weight[: self.freeze_until_idx].copy_(self.embedding_frozen.weight)
+            unsplit_embedding.weight[self.freeze_until_idx :].copy_(self.embedding_trainable.weight)
+
+        return unsplit_embedding
+
+
+class PartiallyFrozenLinear(nn.Module):
+    """A wrapper around nn.Linear to partially freeze part of the weight matrix."""
+
+    def __init__(self, original_linear: nn.Linear, freeze_until_idx: int):
+        """
+        :param original_linear: The original nn.Linear layer.
+        :param freeze_until_idx: The index up to which the rows of the weight matrix are frozen (exclusive).
+        """
+        super().__init__()
+        assert original_linear.bias is None, "Currently only support linear module without bias"
+
+        self.freeze_until_idx = freeze_until_idx
+        self.input_dim = original_linear.in_features
+        self.output_dim = original_linear.out_features
+
+        # Create frozen and trainable linear layers
+        self.linear_frozen = nn.Linear(
+            self.input_dim,
+            freeze_until_idx,
+            bias=False,
+            dtype=original_linear.weight.dtype,
+            device=original_linear.weight.device,
+        )
+        self.linear_trainable = nn.Linear(
+            self.input_dim,
+            self.output_dim - freeze_until_idx,
+            bias=False,
+            dtype=original_linear.weight.dtype,
+            device=original_linear.weight.device,
+        )
+
+        # Copy weights from the original linear layer
+        with torch.no_grad():
+            self.linear_frozen.weight.copy_(original_linear.weight[:freeze_until_idx])
+            self.linear_trainable.weight.copy_(original_linear.weight[freeze_until_idx:])
+
+        # Freeze the frozen linear layer
+        self.linear_frozen.weight.requires_grad = False
+
+    def forward(self, input_tensor):
+        # input_tensor: (bsz, seq_len, hidden_state_dim)
+        frozen_output = self.linear_frozen(input_tensor)
+        trainable_output = self.linear_trainable(input_tensor)
+        return torch.cat((frozen_output, trainable_output), dim=-1)
+
+    def to_unsplit(self) -> nn.Linear:
+        unsplit_linear = nn.Linear(
+            self.input_dim,
+            self.output_dim,
+            bias=False,
+            dtype=self.linear_frozen.weight.dtype,
+            device=self.linear_frozen.weight.device,
+        )
+
+        # Copy weights from the frozen and trainable layers into the unsplit linear layer
+        with torch.no_grad():
+            unsplit_linear.weight[: self.freeze_until_idx].copy_(self.linear_frozen.weight)
+            unsplit_linear.weight[self.freeze_until_idx :].copy_(self.linear_trainable.weight)
+
+        return unsplit_linear
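A hedged sketch of the intended use (not part of the commit): keep an original text vocabulary frozen while training newly appended audio-token rows. The sizes and names below are hypothetical.

    import torch
    import torch.nn as nn

    base = nn.Embedding(32000 + 1024, 4096)               # old vocab + appended tokens
    emb = PartiallyFrozenEmbedding(base, freeze_until_idx=32000)
    out = emb(torch.tensor([[5, 31999, 32001]]))          # mixes frozen and trainable rows
    assert out.shape == (1, 3, 4096)
    merged = emb.to_unsplit()                             # back to one nn.Embedding for export

PartiallyFrozenLinear plays the same role for a bias-free output head: rows below freeze_until_idx stay fixed while the appended rows receive gradients.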
    	
higgs_audio/model/modeling_higgs_audio.py
ADDED
The diff for this file is too large to render. See raw diff.
    	
higgs_audio/model/utils.py
ADDED
@@ -0,0 +1,778 @@
| 1 | 
         
            +
            import contextlib
         
     | 
| 2 | 
         
            +
            from contextlib import contextmanager
         
     | 
| 3 | 
         
            +
            from functools import wraps
         
     | 
| 4 | 
         
            +
            import torch
         
     | 
| 5 | 
         
            +
            from transformers.integrations import is_deepspeed_available
         
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            if is_deepspeed_available():
         
     | 
| 8 | 
         
            +
                from deepspeed.utils import groups as deepspeed_groups
         
     | 
| 9 | 
         
            +
                from deepspeed.sequence.layer import _SeqAllToAll
         
     | 
| 10 | 
         
            +
            else:
         
     | 
| 11 | 
         
            +
                deepspeed_groups = None
         
     | 
| 12 | 
         
            +
                _SeqAllToAll = None
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
             
     | 
| 15 | 
         
            +
            def _ceil_to_nearest(n, round_to):
         
     | 
| 16 | 
         
            +
                return (n + round_to - 1) // round_to * round_to
         
     | 
| 17 | 
         
            +
             
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
            +
            def count_parameters(model, trainable_only=True):
         
     | 
| 20 | 
         
            +
                if trainable_only:
         
     | 
| 21 | 
         
            +
                    return sum(p.numel() for p in model.parameters() if p.requires_grad)
         
     | 
| 22 | 
         
            +
                else:
         
     | 
| 23 | 
         
            +
                    return sum(p.numel() for p in model.parameters())
         
     | 
| 24 | 
         
            +
             
     | 
| 25 | 
         
            +
             
     | 
| 26 | 
         
            +
# TODO(sxjscience) Consider moving the function to audio_processing/utils.py
def build_delay_pattern_mask(
    input_ids: torch.LongTensor,
    bos_token_id: int,
    pad_token_id: int,
):
    """Implement the delay pattern proposed in "Simple and Controllable Music Generation", https://arxiv.org/pdf/2306.05284

    In the delay pattern, each codebook is offset from the previous codebook by
    one position. We insert special delay tokens at the start of each delayed codebook sequence, and append pad tokens once the sequence finishes.

    Take the example where there are 4 codebooks and the audio sequence length is 5. After shifting, the output has length seq_len + num_codebooks - 1:

    - [ *,  *,  *,  *,  *,  P,  P,  P]
    - [ B,  *,  *,  *,  *,  *,  P,  P]
    - [ B,  B,  *,  *,  *,  *,  *,  P]
    - [ B,  B,  B,  *,  *,  *,  *,  *]

    where B indicates the delay token id, P is the special padding token id, and `*` indicates an original audio token.

    Now let's consider the case where we have a sequence of audio tokens to condition on.
    The audio tokens were originally in the following non-delayed form:

    - [a, b]
    - [c, d]
    - [e, f]
    - [g, h]

    After conversion, we get the following delayed form:

    - [a, b, -1, -1, -1]
    - [B, c,  d, -1, -1]
    - [B, B,  e,  f, -1]
    - [B, B,  B,  g,  h]

    Note that the special token `-1` marks a position that should be filled with a newly generated token.
    In that case, we should override the `-1` tokens in auto-regressive generation.

    Args:
        input_ids (:obj:`torch.LongTensor`):
            The input ids of the prompt. It will have shape (bsz, num_codebooks, seq_len).
        bos_token_id (:obj:`int`):
            The id of the special delay token.
        pad_token_id (:obj:`int`):
            The id of the padding token. Should be the same as eos_token_id.

    Returns:
        input_ids (:obj:`torch.LongTensor`):
            The transformed input ids with delay pattern applied. It will have shape (bsz, num_codebooks, seq_len + num_codebooks - 1).
        input_ids_with_gen_mask (:obj:`torch.LongTensor`):
            The transformed input ids with delay pattern applied. The -1 entries in the output indicate new tokens that should be generated.

    """
    bsz, num_codebooks, seq_len = input_ids.shape

    new_seq_len = seq_len + num_codebooks - 1
    input_ids_with_gen_mask = torch.ones((bsz, num_codebooks, new_seq_len), dtype=torch.long, device=input_ids.device)
    bos_mask = torch.tril(input_ids_with_gen_mask, -1) > 0
    eos_mask = torch.triu(input_ids_with_gen_mask, seq_len) > 0
    input_ids_with_gen_mask[bos_mask] = bos_token_id
    input_ids_with_gen_mask[(~bos_mask) & (~eos_mask)] = input_ids.reshape(-1)
    input_ids = input_ids_with_gen_mask.clone()
    input_ids[eos_mask] = pad_token_id
    input_ids_with_gen_mask[eos_mask] = -1
    return input_ids, input_ids_with_gen_mask


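# Illustrative sketch (added; not part of the original file): build_delay_pattern_mask
# on a toy batch with 3 codebooks and 2 audio tokens. The bos/pad ids 0 and 1 are
# arbitrary demo values, not the model's real special-token ids.
def _demo_build_delay_pattern_mask():
    toy = torch.arange(2, 8, dtype=torch.long).reshape(1, 3, 2)
    ids, ids_with_gen = build_delay_pattern_mask(toy, bos_token_id=0, pad_token_id=1)
    # Output length is seq_len + num_codebooks - 1 = 2 + 3 - 1 = 4; codebook i is
    # shifted right by i steps, padded with bos on the left and pad / -1 on the right.
    assert ids.shape == (1, 3, 4)
    assert (ids_with_gen == -1).sum() == (ids == 1).sum()

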
def revert_delay_pattern(data):
    """Convert samples encoded with the delay pattern back to the original form.

    Args:
        data (:obj:`torch.Tensor`):
            The data with delay pattern applied. It will have shape (num_codebooks, seq_len + num_codebooks - 1).

    Returns:
        ret (:obj:`torch.Tensor`):
            Recovered data with delay pattern removed. It will have shape (num_codebooks, seq_len).
    """
    assert len(data.shape) == 2
    out_l = []
    num_codebooks = data.shape[0]
    for i in range(num_codebooks):
        out_l.append(data[i : (i + 1), i : (data.shape[1] - num_codebooks + 1 + i)])
    return torch.cat(out_l, dim=0)


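# Round-trip sketch (added; not part of the original file): revert_delay_pattern undoes
# build_delay_pattern_mask for a single batch element once no -1 slots remain.
def _demo_revert_delay_pattern():
    toy = torch.arange(2, 8, dtype=torch.long).reshape(1, 3, 2)
    delayed, _ = build_delay_pattern_mask(toy, bos_token_id=0, pad_token_id=1)
    recovered = revert_delay_pattern(delayed[0])
    assert torch.equal(recovered, toy[0])

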
def merge_input_ids_with_audio_features(
    audio_features_embed,
    audio_features_length,
    audio_in_embed,
    audio_in_ids_start,
    audio_out_embed,
    audio_out_ids_start,
    audio_in_token_idx,
    audio_out_token_idx,
    inputs_embeds,
    input_ids,
    attention_mask,
    label_ids,
    pad_token_id,
    ignore_index=-100,
    round_to=8,
    left_padding=True,
):
    """
    Merge input_ids with audio features into final embeddings.

    Args:
        audio_features_embed (`torch.Tensor` of shape `(num_audios, max_audio_tokens, embed_dim)`):
            Encoded vectors of all audios in the batch (obtained from the semantic encoder)
        audio_features_length (`torch.LongTensor` of shape `(num_audios,)`):
            The length of the audio embeddings of each audio as stacked in `audio_features_embed`
        audio_in_embed (`torch.Tensor` of shape `(total_num_audio_in_tokens, embed_dim)`):
            The embeddings of audio-in tokens
        audio_in_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-in tokens for each audio
        audio_out_embed (`torch.Tensor` of shape `(total_num_audio_out_tokens, embed_dim)`):
            The embeddings of audio-out tokens
        audio_out_ids_start (`torch.LongTensor` of shape `(num_audios,)`):
            The start index of the audio-out tokens for each audio
        audio_in_token_idx (`int`):
            The index of the audio-in token in the vocabulary
        audio_out_token_idx (`int`):
            The index of the audio-out token in the vocabulary
        inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
            Token embeddings before merging with audio embeddings
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Input ids of the tokens, possibly filled with audio tokens
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Mask to avoid performing attention on padding token indices
        label_ids (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels; recalculated here to support training (if provided)
        pad_token_id (`int`):
            The index of the pad token in the vocabulary
        ignore_index (`int`):
            The index to ignore in the loss calculation
        round_to (`int`):
            The multiple to round the padded length up to
        left_padding (`bool`):
            Whether to apply left padding

    Returns:
        final_embedding
            The final embeddings after merging audio embeddings with text embeddings.
        final_attention_mask
            The final attention mask after merging audio embeddings with text embeddings.
        final_labels
            The labels for the text stream.
        position_ids
            Position ids for the merged data.
        final_input_ids
            The final input_ids after merging audio embeddings with text embeddings.
        final_audio_in_mask
            Mask for audio-in embeddings.
        final_audio_in_discrete_codes_mask
            Mask for audio-in discrete tokens.
        final_audio_out_mask
            Mask for audio-out embeddings.

    Explanation:
        Each audio has variable-length embeddings, with lengths specified by
        - audio_features_length
        - audio_in_ids_start
        - audio_out_ids_start

        Task:
        - fill each <|AUDIO|> with audio embeddings (which can be a combination of embeddings extracted by the WhisperEncoder and embeddings from audio codebooks)
        - fill each <|AUDIO_OUT|> with the audio-out embeddings

        Example:
            <|AUDIO_OUT|>: X (5 tokens), Y (3 tokens)
            <|AUDIO|>: Z (8 tokens)

            X, Y are in the same sequence (in-context voice-clone). Z is in a different sequence (audio understanding).
        if right padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                o p q r Z s t u v _ _ _ _ _ _
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
            ]
        elif left padding
            input_ids: [
                a b c d e f X g h i j k Y l m
                _ _ _ _ _ _ o p q r Z s t u v
            ]
            input_ids should be: [
                a b c d e f X X X X X g h i j k Y Y Y l m
                _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
            ]
            labels should be: [
                a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
                _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
            ]

    """
    if label_ids is None:
        skip_labels = True
    else:
        skip_labels = False
    if audio_features_embed is not None and audio_features_embed.shape[0] == 0:
        audio_features_embed = None
    if audio_in_embed is not None and audio_in_embed.shape[0] == 0:
        audio_in_embed = None
    if audio_out_embed is not None and audio_out_embed.shape[0] == 0:
        audio_out_embed = None

    batch_size, sequence_length, embed_dim = inputs_embeds.shape

    target_device = inputs_embeds.device
    if left_padding is None:
        left_padding = torch.any(attention_mask[:, 0] == 0)

    audio_in_token_mask = input_ids == audio_in_token_idx
    audio_out_token_mask = input_ids == audio_out_token_idx
    text_token_mask = (input_ids != audio_in_token_idx) & (input_ids != audio_out_token_idx)

    # 1. Calculate the number of tokens for each placeholder (like [<|AUDIO|>, <|AUDIO_OUT|>]).
    token_placeholder_num = torch.ones_like(input_ids)

    if audio_features_embed is not None:
        num_audios, max_audio_tokens, _ = audio_features_embed.shape
        audio_in_features_mask = torch.arange(max_audio_tokens).expand(num_audios, max_audio_tokens).to(
            audio_features_length.device
        ) < audio_features_length.unsqueeze(1)
        masked_audio_in_features = audio_features_embed[audio_in_features_mask].view(-1, embed_dim)
        token_placeholder_num[audio_in_token_mask] = audio_features_length.long()

    if audio_in_embed is not None:
        audio_in_codes_length = torch.concat(
            [
                audio_in_ids_start[1:] - audio_in_ids_start[:-1],
                torch.tensor(
                    [audio_in_embed.shape[0] - audio_in_ids_start[-1]],
                    device=audio_in_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        if audio_features_embed is not None:
            token_placeholder_num[audio_in_token_mask] += audio_in_codes_length.long()
        else:
            token_placeholder_num[audio_in_token_mask] = audio_in_codes_length.long()

    if audio_out_embed is not None:
        audio_out_codes_length = torch.concat(
            [
                audio_out_ids_start[1:] - audio_out_ids_start[:-1],
                torch.tensor(
                    [audio_out_embed.shape[0] - audio_out_ids_start[-1]],
                    device=audio_out_ids_start.device,
                    dtype=torch.long,
                ),
            ],
            dim=0,
        )
        token_placeholder_num[audio_out_token_mask] = audio_out_codes_length.long()

    new_token_positions = torch.cumsum(token_placeholder_num, -1) - 1
    max_token_num = _ceil_to_nearest(token_placeholder_num.sum(-1).max(), round_to)
    nb_audio_pad = max_token_num - 1 - new_token_positions[:, -1]

    if left_padding:
        new_token_positions += nb_audio_pad[:, None]  # offset for left padding

    # 2. Create the full embedding, already padded to the maximum position
    final_embedding = torch.zeros(
        (batch_size, max_token_num, embed_dim),
        dtype=inputs_embeds.dtype,
        device=inputs_embeds.device,
    )
    final_attention_mask = torch.zeros(
        (batch_size, max_token_num),
        dtype=attention_mask.dtype,
        device=inputs_embeds.device,
    )
    final_input_ids = torch.full(
        (batch_size, max_token_num),
        pad_token_id,
        dtype=input_ids.dtype,
        device=inputs_embeds.device,
    )
    if skip_labels:
        final_labels = None
    else:
        final_labels = torch.full(
            (batch_size, max_token_num),
            ignore_index,
            dtype=label_ids.dtype,
            device=inputs_embeds.device,
        )

    final_audio_in_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_in_discrete_codes_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    final_audio_out_mask = torch.full(
        (batch_size, max_token_num),
        False,
        dtype=torch.bool,
        device=inputs_embeds.device,
    )
    # 3. Get the audio-in token positions and audio-out token positions
    batch_id = torch.arange(batch_size, device=target_device).unsqueeze(1).expand(batch_size, sequence_length)
    audio_in_batch_id = batch_id[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_batch_id = batch_id[audio_out_token_mask]  # Shape (num_audio_out,)
    audio_features_token_ends = new_token_positions[audio_in_token_mask]  # Shape (num_audio_in,)
    audio_out_embed_ends = new_token_positions[audio_out_token_mask]  # Shape (num_audio_out,)

    if audio_in_embed is not None:
        # Fill in the audio-in embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_in_ids_start.shape[0], max_token_num)
        )
        audio_in_embed_token_starts = audio_features_token_ends - audio_in_codes_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_in_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_in_embed
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True
        final_audio_in_discrete_codes_mask[batch_indices, col_indices] = True
        audio_features_token_ends = audio_features_token_ends - audio_in_codes_length

    if audio_features_embed is not None:
        # Fill in the audio features
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_features_embed.shape[0], max_token_num)
        )
        audio_features_token_starts = audio_features_token_ends - audio_features_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_features_token_starts.unsqueeze(1))
            & (seq_indices <= audio_features_token_ends.unsqueeze(1))
        )
        batch_indices = audio_in_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = masked_audio_in_features
        final_input_ids[batch_indices, col_indices] = audio_in_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_in_mask[batch_indices, col_indices] = True

    if audio_out_embed is not None:
        # Fill in the audio-out embeddings
        seq_indices = (
            torch.arange(max_token_num, device=target_device)
            .unsqueeze(0)
            .expand(audio_out_ids_start.shape[0], max_token_num)
        )
        audio_out_embed_token_starts = audio_out_embed_ends - audio_out_codes_length + 1
        batch_indices, col_indices = torch.where(
            (seq_indices >= audio_out_embed_token_starts.unsqueeze(1))
            & (seq_indices <= audio_out_embed_ends.unsqueeze(1))
        )
        batch_indices = audio_out_batch_id[batch_indices]
        final_embedding[batch_indices, col_indices] = audio_out_embed
        final_input_ids[batch_indices, col_indices] = audio_out_token_idx
        if not skip_labels:
            final_labels[batch_indices, col_indices] = ignore_index
        final_audio_out_mask[batch_indices, col_indices] = True

    # Fill in the original text embeddings and labels
    batch_indices, non_audio_indices = torch.where(text_token_mask)
    text_to_overwrite = new_token_positions[batch_indices, non_audio_indices]
    final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_audio_indices]
    if not skip_labels:
        final_labels[batch_indices, text_to_overwrite] = label_ids[batch_indices, non_audio_indices]
    final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_audio_indices]
    final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_audio_indices]
    final_attention_mask = final_attention_mask | final_audio_in_mask | final_audio_out_mask

    # Trim the tensor if there are redundant padding tokens
    if left_padding:
        first_non_zero_loc = final_attention_mask.sum(0).nonzero()[0]
        first_non_zero_loc = (first_non_zero_loc // round_to) * round_to
        if first_non_zero_loc > 0:
            final_attention_mask = final_attention_mask[:, first_non_zero_loc:]
            final_embedding = final_embedding[:, first_non_zero_loc:]
            if not skip_labels:
                final_labels = final_labels[:, first_non_zero_loc:]
            final_input_ids = final_input_ids[:, first_non_zero_loc:]
            final_audio_in_mask = final_audio_in_mask[:, first_non_zero_loc:]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, first_non_zero_loc:]
            final_audio_out_mask = final_audio_out_mask[:, first_non_zero_loc:]
    else:
        # We have done right padding, so we need to trim the mask
        last_non_zero_loc = final_attention_mask.sum(0).nonzero()[-1] + 1
        last_non_zero_loc = ((last_non_zero_loc + round_to - 1) // round_to) * round_to
        if last_non_zero_loc < max_token_num:
            final_attention_mask = final_attention_mask[:, :last_non_zero_loc]
            final_embedding = final_embedding[:, :last_non_zero_loc]
            if not skip_labels:
                final_labels = final_labels[:, :last_non_zero_loc]
            final_input_ids = final_input_ids[:, :last_non_zero_loc]
            final_audio_in_mask = final_audio_in_mask[:, :last_non_zero_loc]
            final_audio_in_discrete_codes_mask = final_audio_in_discrete_codes_mask[:, :last_non_zero_loc]
            final_audio_out_mask = final_audio_out_mask[:, :last_non_zero_loc]

    position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
    return (
        final_embedding,
        final_attention_mask,
        final_labels,
        position_ids,
        final_input_ids,
        final_audio_in_mask,
        final_audio_in_discrete_codes_mask,
        final_audio_out_mask,
    )


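# Toy invocation sketch (added; not part of the original file): a single <|AUDIO|>
# placeholder (id 99, an arbitrary demo value) expands into 3 discrete audio-in codes.
# Only the discrete-code path is exercised; Whisper features and audio-out are omitted.
def _demo_merge_input_ids_with_audio_features():
    out = merge_input_ids_with_audio_features(
        audio_features_embed=None,
        audio_features_length=None,
        audio_in_embed=torch.ones(3, 4),
        audio_in_ids_start=torch.tensor([0]),
        audio_out_embed=None,
        audio_out_ids_start=None,
        audio_in_token_idx=99,
        audio_out_token_idx=98,
        inputs_embeds=torch.zeros(1, 5, 4),
        input_ids=torch.tensor([[10, 11, 99, 12, 13]]),
        attention_mask=torch.ones(1, 5, dtype=torch.long),
        label_ids=None,
        pad_token_id=0,
    )
    final_input_ids, final_audio_in_mask = out[4], out[5]
    # 4 text tokens + 3 code positions = 7, rounded up to the next multiple of 8.
    assert final_input_ids.shape == (1, 8)
    assert final_audio_in_mask.sum().item() == 3

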
def is_deepspeed_ulysses_enabled():
    """Check if sequence parallelism (DeepSpeed Ulysses) is enabled."""
    if deepspeed_groups is None:
        return False
    return deepspeed_groups._get_sequence_parallel_world_size() > 1


def support_deepspeed_ulysses(module):
    """A decorator around a PyTorch module class. It is needed for modules that need access to sequence-parallel info."""
    module._sp_size = None
    module._sp_rank = None
    module._sp_group = None

    @property
    def sp_size(self):
        if self._sp_size is None:
            self._sp_size = 1
            if is_deepspeed_ulysses_enabled():
                self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
        return self._sp_size

    @property
    def sp_rank(self):
        if self._sp_rank is None:
            self._sp_rank = 0
            if is_deepspeed_ulysses_enabled():
                self._sp_rank = deepspeed_groups._get_sequence_parallel_rank()
        return self._sp_rank

    @property
    def sp_group(self):
        if self._sp_group is None and is_deepspeed_ulysses_enabled():
            self._sp_group = deepspeed_groups._get_sequence_parallel_group()
        return self._sp_group

    module.sp_size = sp_size
    module.sp_rank = sp_rank
    module.sp_group = sp_group

    return module


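# Usage sketch (added; not part of the original file): attach lazily evaluated
# sp_size / sp_rank / sp_group properties to a module class. Without DeepSpeed
# sequence parallelism initialized, they fall back to sp_size=1 and sp_rank=0.
@support_deepspeed_ulysses
class _DemoSPModule(torch.nn.Module):
    def forward(self, hidden_states):
        # Under Ulysses each rank owns seq_len // sp_size positions of the sequence.
        return hidden_states.chunk(self.sp_size, dim=1)[self.sp_rank]

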
def deepspeed_ulysses_attention(seq_dim=1, head_dim=2):
    """Perform all-to-all before and after the attention function."""

    def attention_decorator(attn_func=None):
        def wrapped(*args, **kwargs):
            if is_deepspeed_ulysses_enabled():
                sp_group = deepspeed_groups._get_sequence_parallel_group()
                scatter_idx = head_dim  # Scatter on the num_heads dimension
                gather_idx = seq_dim  # Gather on the seq_len dimension
                batch_dim_idx = 0
                args = list(args)
                # Redistribute q, k, v: each rank trades its sequence chunk for the
                # full sequence over a subset of heads.
                args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
                args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
                args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
                args = tuple(args)

            attn_output = attn_func(*args, **kwargs)

            if is_deepspeed_ulysses_enabled():
                scatter_idx = seq_dim  # Scatter back on the seq_len dimension
                gather_idx = head_dim  # Gather on the num_heads dimension
                batch_dim_idx = 0
                attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)

            return attn_output

        return wrapped

    return attention_decorator


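# Usage sketch (added; not part of the original file): wrap an attention implementation
# so q/k/v (args[0:3], assumed laid out as (batch, seq_len, num_heads, head_dim)) are
# exchanged across the sequence-parallel group before attention and re-sharded after.
# Without Ulysses enabled, the wrapper is a no-op around the attention function.
@deepspeed_ulysses_attention(seq_dim=1, head_dim=2)
def _demo_sdpa(q, k, v):
    # F.scaled_dot_product_attention expects (batch, num_heads, seq_len, head_dim).
    out = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)

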
def deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1):
    """Slice the corresponding cos and sin chunks for RoPE."""

    def rope_decorator(rope_func=None):
        def wrapped(*args, **kwargs):
            if is_deepspeed_ulysses_enabled():
                sp_rank = deepspeed_groups._get_sequence_parallel_rank()
                args = list(args)
                # Each rank only holds a chunk of the sequence, so slice the matching
                # chunk of the cos (args[2]) and sin (args[3]) tables.
                seq_chunk_size = args[0].size(state_seq_dim)
                args[2] = torch.narrow(args[2], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
                args[3] = torch.narrow(args[3], trig_seq_dim, sp_rank * seq_chunk_size, seq_chunk_size)
                args = tuple(args)

            return rope_func(*args, **kwargs)

        return wrapped

    return rope_decorator


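# Usage sketch (added; not part of the original file): a RoPE application wrapped so
# each rank slices its own chunk of the precomputed cos/sin tables (args[2], args[3]).
# The rotation below is a placeholder for illustration, not a real RoPE kernel.
@deepspeed_ulysses_rope(state_seq_dim=2, trig_seq_dim=1)
def _demo_apply_rope(q, k, cos, sin):
    # q, k: (batch, num_heads, seq_len, head_dim); cos, sin: (batch, seq_len, head_dim).
    return q * cos.unsqueeze(1), k * sin.unsqueeze(1)

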
def _gather_tensors(input_, group=None):
    """Gather tensors (possibly of different shapes) from all ranks, returning a list of per-rank buffers."""
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    # First all-gather the shapes, since the tensors may differ in size across ranks.
    tensor_shapes = [
        torch.empty(len(input_.size()), dtype=torch.int64, device=input_.device) for _ in range(world_size)
    ]
    input_size = torch.tensor(input_.size(), dtype=torch.int64, device=input_.device)
    torch.distributed.all_gather(tensor_shapes, input_size, group=group)
    # Then all-gather the tensors themselves into correctly sized buffers.
    gathered_buffers = [
        torch.empty(tensor_shapes[i].tolist(), dtype=input_.dtype, device=input_.device) for i in range(world_size)
    ]
    torch.distributed.all_gather(gathered_buffers, input_, group=group)
    return gathered_buffers


def _scatter_tensors(input_, group=None):
    """Scatter tensors."""
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    rank = torch.distributed.get_rank(group)
    return input_[rank]


class _GatherTensors(torch.autograd.Function):
    """All-gather tensors among the ranks."""

    @staticmethod
    def symbolic(graph, input_, group):
        return _gather_tensors(input_, group)

    @staticmethod
    def forward(ctx, input_, group):
        ctx.group = group
        return torch.nested.as_nested_tensor(_gather_tensors(input_, group), layout=torch.jagged)

    @staticmethod
    def backward(ctx, grad_output):
        return _scatter_tensors(grad_output, ctx.group), None


            def all_gather_tensors(input_, size=None, dim=0, group=None):
         
     | 
| 597 | 
         
            +
                if torch.distributed.get_world_size(group) == 1:
         
     | 
| 598 | 
         
            +
                    # no sequence parallelism
         
     | 
| 599 | 
         
            +
                    return input_
         
     | 
| 600 | 
         
            +
                gathered_tensors = _GatherTensors.apply(input_, group)
         
     | 
| 601 | 
         
            +
             
     | 
| 602 | 
         
            +
                if size:
         
     | 
| 603 | 
         
            +
                    split_gathered_tensors = []
         
     | 
| 604 | 
         
            +
                    for s, gathered_tensor in zip(size, gathered_tensors):
         
     | 
| 605 | 
         
            +
                        split_gathered_tensor = torch.split(gathered_tensor, s.tolist())
         
     | 
| 606 | 
         
            +
                        split_gathered_tensors.append(split_gathered_tensor)
         
     | 
| 607 | 
         
            +
             
     | 
| 608 | 
         
            +
                    gathered_tensors = [y for x in zip(*split_gathered_tensors) for y in x]
         
     | 
| 609 | 
         
            +
             
     | 
| 610 | 
         
            +
                return torch.cat(gathered_tensors, dim).contiguous()
         
     | 
| 611 | 
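A minimal single-process sketch of the behavior (hedged: it assumes the helpers above are importable from this module; with world_size == 1 they short-circuit, so no multi-GPU launch is needed):

    import torch
    import torch.distributed as dist

    # One-process gloo group: all_gather_tensors takes the world_size == 1 fast
    # path and returns its input unchanged, with no collective communication.
    dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
    x = torch.randn(3, 2)
    assert all_gather_tensors(x) is x
    dist.destroy_process_group()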


def get_sequence_data_parallel_world_size():
    return torch.distributed.get_world_size()


def get_sequence_data_parallel_rank():
    return torch.distributed.get_rank()


def get_sequence_data_parallel_group():
    return torch.distributed.group.WORLD


if is_deepspeed_available():
    deepspeed_groups._get_sequence_data_parallel_world_size = get_sequence_data_parallel_world_size
    deepspeed_groups._get_sequence_data_parallel_rank = get_sequence_data_parallel_rank
    deepspeed_groups._get_sequence_data_parallel_group = get_sequence_data_parallel_group


def _gather_tokens(input_, dim=0, group=None):
    """Gather tensors and concatenate them along a dimension."""
    input_ = input_.contiguous()
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_

    gather_buffer = torch.empty(world_size * input_.numel(), dtype=input_.dtype, device=input_.device)
    torch.distributed.all_gather_into_tensor(gather_buffer, input_, group=group)
    if dim == 0:
        shape = list(input_.size())
        shape[0] = shape[0] * world_size
        output = gather_buffer.view(shape)
    else:
        tensor_list = [
            gather_buffer.narrow(0, input_.numel() * i, input_.numel()).view_as(input_) for i in range(world_size)
        ]
        # Note: torch.cat already creates a contiguous tensor.
        output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


def _drop_tokens(input_, dim=0, group=None):
    """Divide a tensor among the sequence parallel ranks."""
    world_size = torch.distributed.get_world_size(group)
    if world_size == 1:
        return input_
    this_rank = torch.distributed.get_rank(group)
    assert input_.shape[dim] % world_size == 0, (
        f"input dimension {dim} ({input_.shape[dim]}) is not divisible by sequence parallel world size ({world_size})"
    )
    chunk_size = input_.shape[dim] // world_size

    return torch.narrow(input_, dim, this_rank * chunk_size, chunk_size)


class _DropTokens(torch.autograd.Function):
    """Divide tokens equally among the sequence parallel ranks."""

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _drop_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _gather_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            grad_input /= ctx.grad_scale
        return grad_input, None, None, None


class _GatherTokens(torch.autograd.Function):
    """Gather tokens among the sequence parallel ranks."""

    @staticmethod
    def symbolic(graph, input_, dim, group, grad_scale):
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def forward(ctx, input_, dim, group, grad_scale):
        ctx.dim = dim
        ctx.group = group
        ctx.grad_scale = grad_scale
        return _gather_tokens(input_, dim, group)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = _drop_tokens(grad_output, ctx.dim, ctx.group)
        if ctx.grad_scale != 1:
            grad_input *= ctx.grad_scale
        return grad_input, None, None, None


def drop_tokens(input_, dim=0, group=None, grad_scale=1):
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    return _DropTokens.apply(input_, dim, group, grad_scale)


def gather_tokens(input_, dim=0, group=None, grad_scale=1):
    if torch.distributed.get_world_size(group) == 1:
        # no sequence parallelism
        return input_
    return _GatherTokens.apply(input_, dim, group, grad_scale)
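As a sketch of the intended round trip under sequence parallelism (hedged: this assumes a two-process launch, e.g. `torchrun --nproc_per_node=2 example.py`, with the wrappers above importable; with a single process both are no-ops):

    import torch
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # rank/world size come from torchrun env vars
    full = torch.arange(8.0).view(1, 8)      # identical full sequence on every rank
    local = drop_tokens(full, dim=1)         # rank 0 keeps columns 0-3, rank 1 keeps 4-7
    restored = gather_tokens(local, dim=1)   # every rank reassembles the full sequence
    assert torch.equal(restored, full)
    dist.destroy_process_group()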


def sequence_chunking_per_rank(sp_size, sp_rank, *args, dim=1):
    """
    Slice the inputs to create chunks per sequence parallel rank. This is used for context parallel training.

    Args:
        sp_size (`int`):
            Sequence parallel size.
        sp_rank (`int`):
            Sequence parallel rank for the current process.
        *args (`torch.Tensor`):
            The tensors to slice; all must share the same size along `dim`.
        dim (`int`):
            The dimension to slice.
    """
    if sp_size == 1:
        return args[0] if len(args) == 1 else args

    seq_length = args[0].size(dim)
    for arg in args[1:]:
        assert arg.size(dim) == seq_length, (
            f"arg={arg} ({arg.shape[dim]}) does not have the same size as args[0] ({seq_length}) in dimension {dim}"
        )
    assert seq_length % sp_size == 0, (
        f"dimension {dim} ({args[0].shape[dim]}) is not divisible by sequence parallel world size ({sp_size})"
    )

    sub_seq_length = seq_length // sp_size
    sub_seq_start = sp_rank * sub_seq_length

    output = []
    for ind in args:
        ind = torch.narrow(ind, dim, sub_seq_start, sub_seq_length)
        output.append(ind)

    return tuple(output) if len(output) > 1 else output[0]
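For example, a quick local sketch of the slicing (assuming only torch and the function above; no distributed setup is required since the function takes the rank explicitly):

    import torch

    hidden = torch.randn(1, 8, 16)                 # (batch, seq, dim)
    mask = torch.ones(1, 8, dtype=torch.bool)
    h0, m0 = sequence_chunking_per_rank(2, 0, hidden, mask, dim=1)  # rank 0: first half
    h1, m1 = sequence_chunking_per_rank(2, 1, hidden, mask, dim=1)  # rank 1: second half
    assert h0.shape == (1, 4, 16)
    assert torch.equal(torch.cat([h0, h1], dim=1), hidden)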


@contextmanager
def disable_deepspeed_ulysses():
    """Disable DeepSpeed Ulysses (sequence parallelism) if it is enabled."""
    if is_deepspeed_ulysses_enabled():
        _old_get_sequence_parallel_world_size = deepspeed_groups._get_sequence_parallel_world_size

        def _get_sequence_parallel_world_size():
            return 1

        deepspeed_groups._get_sequence_parallel_world_size = _get_sequence_parallel_world_size
        try:
            yield
        finally:
            deepspeed_groups._get_sequence_parallel_world_size = _old_get_sequence_parallel_world_size
    else:
        context = contextlib.nullcontext
        with context():
            yield
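A usage sketch (hedged: names come from this module; outside a DeepSpeed-Ulysses run `is_deepspeed_ulysses_enabled()` is False and the context manager is a plain passthrough):

    # Run a sub-module that must see the full, un-sharded sequence, e.g. an
    # audio encoder, while the surrounding model is trained with Ulysses.
    with disable_deepspeed_ulysses():
        print("sequence-parallel world size is reported as 1 in here")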
higgs_audio/serve/serve_engine.py ADDED
@@ -0,0 +1,474 @@
import asyncio
import base64
import torch
import numpy as np
from io import BytesIO
from dataclasses import asdict, dataclass, field
from typing import List, Optional, Union
from copy import deepcopy
from transformers import AutoTokenizer, AutoProcessor
from transformers.cache_utils import StaticCache
from transformers.generation.streamers import BaseStreamer
from transformers.generation.stopping_criteria import StoppingCriteria
from loguru import logger
import threading
import librosa


from ..dataset.chatml_dataset import (
    ChatMLSample,
    ChatMLDatasetSample,
    prepare_chatml_sample,
)
from ..model import HiggsAudioModel
from ..model.utils import revert_delay_pattern
from ..data_collator.higgs_audio_collator import HiggsAudioSampleCollator
from ..audio_processing.higgs_audio_tokenizer import load_higgs_audio_tokenizer


def normalize_chinese_punctuation(text):
    """
    Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
    """
    # Mapping of Chinese punctuation to English punctuation
    chinese_to_english_punct = {
        ",": ",",  # comma
        "。": ".",  # period
        ":": ":",  # colon
        ";": ";",  # semicolon
        "?": "?",  # question mark
        "!": "!",  # exclamation mark
        "(": "(",  # left parenthesis
        ")": ")",  # right parenthesis
        "【": "[",  # left square bracket
        "】": "]",  # right square bracket
        "《": "<",  # left angle quote
        "》": ">",  # right angle quote
        "“": '"',  # left double quotation
        "”": '"',  # right double quotation
        "‘": "'",  # left single quotation
        "’": "'",  # right single quotation
        "、": ",",  # enumeration comma
        "—": "-",  # em dash
        "…": "...",  # ellipsis
        "·": ".",  # middle dot
        "「": '"',  # left corner bracket
        "」": '"',  # right corner bracket
        "『": '"',  # left double corner bracket
        "』": '"',  # right double corner bracket
    }

    # Replace each Chinese punctuation mark with its English counterpart
    for zh_punct, en_punct in chinese_to_english_punct.items():
        text = text.replace(zh_punct, en_punct)

    return text
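For instance (a quick sketch, assuming the function above is in scope):

    print(normalize_chinese_punctuation("你好,世界。"))  # -> 你好,世界.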


@dataclass
class HiggsAudioStreamerDelta:
    """Represents a chunk of generated content, either text or audio tokens."""

    text: Optional[str] = None
    text_tokens: Optional[torch.Tensor] = None
    audio_tokens: Optional[torch.Tensor] = None
    finish_reason: Optional[str] = None


class AsyncHiggsAudioStreamer(BaseStreamer):
    """
    Async streamer that handles both text and audio token generation from the Higgs-Audio model.
    Stores chunks in a queue to be consumed by downstream applications.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode text tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt tokens in generation.
        timeout (`float`, *optional*):
            The timeout for the queue. If `None`, the queue will block indefinitely.
        audio_num_codebooks (`int`, *optional*, defaults to `1`):
            The number of audio codebooks; token chunks of this size are treated as audio tokens.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:
        ```python
        >>> from transformers import AutoTokenizer
        >>> from threading import Thread
        >>> import asyncio

        >>> tokenizer = AutoTokenizer.from_pretrained("path/to/higgs/tokenizer")
        >>> model = HiggsAudioModel.from_pretrained("path/to/higgs/model")
        >>> inputs = tokenizer(["Generate some text and audio:"], return_tensors="pt")

        >>> async def main():
        ...     streamer = AsyncHiggsAudioStreamer(tokenizer)
        ...     generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        ...     thread = Thread(target=model.generate, kwargs=generation_kwargs)
        ...     thread.start()
        ...
        ...     async for delta in streamer:
        ...         if delta.text is not None:
        ...             print("Text:", delta.text)
        ...         if delta.audio_tokens is not None:
        ...             print("Audio tokens shape:", delta.audio_tokens.shape)
        >>> asyncio.run(main())
        ```
    """

    def __init__(
        self,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        audio_num_codebooks: int = 1,
        **decode_kwargs,
    ):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.timeout = timeout
        self.decode_kwargs = decode_kwargs
        self.audio_num_codebooks = audio_num_codebooks

        # Queue to store generated chunks
        self.queue = asyncio.Queue()
        self.stop_signal = None

        # Get the running event loop
        self.loop = asyncio.get_running_loop()
        self.has_asyncio_timeout = hasattr(asyncio, "timeout")

        # State tracking
        self.next_tokens_are_prompt = True

    def put(self, value: torch.Tensor):
        """
        Receives tokens and processes them as either text or audio tokens.
        Text tokens are decoded and queued as text deltas; audio tokens are queued directly.
        """
        if value.shape[0] > 1 and not self.next_tokens_are_prompt:
            # This is likely audio tokens (shape: [audio_num_codebooks])
            assert value.shape[0] == self.audio_num_codebooks, "Number of codebooks mismatch"
            delta = HiggsAudioStreamerDelta(audio_tokens=value)
            self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)
            return

        # Skip prompt tokens if configured
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Process as text tokens
        if len(value.shape) > 1:
            value = value[0]

        text = self.tokenizer.decode(value, **self.decode_kwargs)
        delta = HiggsAudioStreamerDelta(text=text, text_tokens=value)
        self.loop.call_soon_threadsafe(self.queue.put_nowait, delta)

    def end(self):
        """Flushes any remaining text tokens and signals the end of generation."""
        self.next_tokens_are_prompt = True
        self.loop.call_soon_threadsafe(self.queue.put_nowait, self.stop_signal)

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            if self.has_asyncio_timeout:
                async with asyncio.timeout(self.timeout):
                    value = await self.queue.get()
            else:
                value = await asyncio.wait_for(self.queue.get(), timeout=self.timeout)
        except asyncio.TimeoutError:
            raise TimeoutError()
        else:
            if value == self.stop_signal:
                raise StopAsyncIteration()
            else:
                return value


class AsyncStoppingCriteria(StoppingCriteria):
    """
    Stopping criteria that checks for a stop signal from a threading event.

    Args:
        stop_signal (threading.Event): Event that will receive stop signals
    """

    def __init__(self, stop_signal: threading.Event):
        self.stop_signal = stop_signal

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        if self.stop_signal.is_set():
            logger.info("Stop signal received. Can be caused by client disconnection.")
            return True
        return False
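A small sketch of the intended wiring (hedged: the `stopping_criteria` argument of `model.generate` is the standard transformers mechanism, not something shown in this diff):

    import threading

    stop_event = threading.Event()
    criteria = AsyncStoppingCriteria(stop_event)
    # Pass stopping_criteria=[criteria] to model.generate(...); another thread
    # (e.g. a disconnect handler) calls stop_event.set() to abort generation.
    print(criteria(input_ids=None, scores=None))  # False: generation continues
    stop_event.set()
    print(criteria(input_ids=None, scores=None))  # True: generation stops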


@dataclass
class HiggsAudioResponse:
    audio: Optional[np.ndarray] = None
    generated_audio_tokens: Optional[np.ndarray] = None
    sampling_rate: Optional[int] = None
    generated_text: str = ""
    # np.ndarray itself is not a valid zero-argument factory (it requires a
    # shape), so default to an empty array instead.
    generated_text_tokens: np.ndarray = field(default_factory=lambda: np.array([]))
    usage: Optional[dict] = None
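A consumption sketch for the response object (hedged: field names follow the dataclass above; `soundfile` is an assumed utility dependency, not one this diff pins):

    import soundfile as sf  # assumed helper dependency

    def save_response_audio(resp: HiggsAudioResponse, path: str) -> None:
        # Persist the generated waveform, if the model produced audio.
        if resp.audio is not None and resp.sampling_rate is not None:
            sf.write(path, resp.audio, resp.sampling_rate)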
         
            +
            class HiggsAudioServeEngine:
         
     | 
| 223 | 
         
            +
                def __init__(
         
     | 
| 224 | 
         
            +
                    self,
         
     | 
| 225 | 
         
            +
                    model_name_or_path: str,
         
     | 
| 226 | 
         
            +
                    audio_tokenizer_name_or_path: str,
         
     | 
| 227 | 
         
            +
                    tokenizer_name_or_path: Optional[str] = None,
         
     | 
| 228 | 
         
            +
                    device: str = "cuda",
         
     | 
| 229 | 
         
            +
                    torch_dtype: Union[torch.dtype, str] = "auto",
         
     | 
| 230 | 
         
            +
                    kv_cache_lengths: List[int] = [1024, 4096, 8192],  # Multiple KV cache sizes
         
     | 
| 231 | 
         
            +
                ):
         
     | 
| 232 | 
         
            +
                    """
         
     | 
| 233 | 
         
            +
                    Initialize the HiggsAudioServeEngine, a serving wrapper for the HiggsAudioModel.
         
     | 
| 234 | 
         
            +
                    The model, tokenizer, and audio tokenizer will be downloaded from the Hugging Face Hub if they are not local.
         
     | 
| 235 | 
         
            +
             
     | 
| 236 | 
         
            +
                    Args:
         
     | 
| 237 | 
         
            +
                        model_name_or_path (str):
         
     | 
| 238 | 
         
            +
                            The name or path of the model to load.
         
     | 
| 239 | 
         
            +
                        audio_tokenizer_name_or_path (str):
         
     | 
| 240 | 
         
            +
                            The name or path of the audio tokenizer to load.
         
     | 
| 241 | 
         
            +
                        tokenizer_name_or_path (str):
         
     | 
| 242 | 
         
            +
                            The name or path of the tokenizer to load.
         
     | 
| 243 | 
         
            +
                        device (str):
         
     | 
| 244 | 
         
            +
                            The device to use for the model.
         
     | 
| 245 | 
         
            +
                        kv_cache_lengths (List[int]):
         
     | 
| 246 | 
         
            +
                            The lengths of the KV caches to use for the model. Used for cuda graph capture when device is cuda.
         
     | 
| 247 | 
         
            +
                        torch_dtype (Union[torch.dtype, str]):
         
     | 
| 248 | 
         
            +
                            The dtype to use for the model.
         
     | 
| 249 | 
         
            +
                    """
         
     | 
| 250 | 
         
            +
                    self.device = device
         
     | 
| 251 | 
         
            +
                    self.model_name_or_path = model_name_or_path
         
     | 
| 252 | 
         
            +
                    self.torch_dtype = torch_dtype
         
     | 
| 253 | 
         
            +
             
     | 
| 254 | 
         
            +
                    # Initialize model and tokenizer
         
     | 
| 255 | 
         
            +
                    self.model = HiggsAudioModel.from_pretrained(model_name_or_path, torch_dtype=torch_dtype).to(device)
         
     | 
| 256 | 
         
            +
                    logger.info(f"Loaded model from {model_name_or_path}, dtype: {self.model.dtype}")
         
     | 
| 257 | 
         
            +
             
     | 
| 258 | 
         
            +
                    if tokenizer_name_or_path is None:
         
     | 
| 259 | 
         
            +
                        tokenizer_name_or_path = model_name_or_path
         
     | 
| 260 | 
         
            +
                    logger.info(f"Loading tokenizer from {tokenizer_name_or_path}")
         
     | 
| 261 | 
         
            +
                    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
         
     | 
| 262 | 
         
            +
             
     | 
| 263 | 
         
            +
                    logger.info(f"Initializing Higgs Audio Tokenizer")
         
     | 
| 264 | 
         
            +
                    self.audio_tokenizer = load_higgs_audio_tokenizer(audio_tokenizer_name_or_path, device=device)
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
            +
                    self.audio_num_codebooks = self.model.config.audio_num_codebooks
         
     | 
| 267 | 
         
            +
                    self.audio_codebook_size = self.model.config.audio_codebook_size
         
     | 
| 268 | 
         
            +
                    self.audio_tokenizer_tps = self.audio_tokenizer.tps
         
     | 
| 269 | 
         
            +
                    self.samples_per_token = int(self.audio_tokenizer.sampling_rate // self.audio_tokenizer_tps)
         
     | 
| 270 | 
         
            +
                    self.hamming_window_len = 2 * self.audio_num_codebooks * self.samples_per_token
         
     | 
| 271 | 
         
            +
                    # Set the audio special tokens
         
     | 
| 272 | 
         
            +
                    self.model.set_audio_special_tokens(self.tokenizer)
         
     | 
| 273 | 
         
            +
             
     | 
| 274 | 
         
            +
                    # Prepare KV caches for different lengths
         
     | 
| 275 | 
         
            +
                    cache_config = deepcopy(self.model.config.text_config)
         
     | 
| 276 | 
         
            +
                    cache_config.num_hidden_layers = self.model.config.text_config.num_hidden_layers
         
     | 
| 277 | 
         
            +
                    if self.model.config.audio_dual_ffn_layers:
         
     | 
| 278 | 
         
            +
                        cache_config.num_hidden_layers += len(self.model.config.audio_dual_ffn_layers)
         
     | 
| 279 | 
         
            +
                    # A list of KV caches for different lengths
         
     | 
| 280 | 
         
            +
                    self.kv_caches = {
         
     | 
| 281 | 
         
            +
                        length: StaticCache(
         
     | 
| 282 | 
         
            +
                            config=cache_config,
         
     | 
| 283 | 
         
            +
                            max_batch_size=1,
         
                max_cache_len=length,
                device=self.model.device,
                dtype=self.model.dtype,
            )
            for length in sorted(kv_cache_lengths)
        }

        if self.model.config.encode_whisper_embed:
            logger.info("Loading whisper processor")
            whisper_processor = AutoProcessor.from_pretrained(
                "openai/whisper-large-v3-turbo",
                trust_remote_code=True,
                device=self.device,
            )
        else:
            whisper_processor = None

        # Reuse collator to prepare inference samples
        self.collator = HiggsAudioSampleCollator(
            whisper_processor=whisper_processor,
            encode_whisper_embed=self.model.config.encode_whisper_embed,
            audio_in_token_id=self.model.config.audio_in_token_idx,
            audio_out_token_id=self.model.config.audio_out_token_idx,
            audio_stream_bos_id=self.model.config.audio_stream_bos_id,
            audio_stream_eos_id=self.model.config.audio_stream_eos_id,
            pad_token_id=self.model.config.pad_token_id,
            return_audio_in_tokens=False,
            use_delay_pattern=self.model.config.use_delay_pattern,
            audio_num_codebooks=self.model.config.audio_num_codebooks,
            round_to=1,
        )

        # Lock to prevent multiple generations from happening at the same time
        self.generate_lock = threading.Lock()

        # Capture CUDA graphs for each KV cache length
        # if device == "cuda":
        #     logger.info(f"Capturing CUDA graphs for each KV cache length")
        #     self.model.capture_model(self.kv_caches.values())

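    # Illustrative only — a minimal sketch of how a prompt could be routed to one of
    # the bucketed static KV caches built above, assuming buckets are keyed by their
    # max_cache_len (`_select_kv_bucket` is a hypothetical helper, not part of this commit):
    def _select_kv_bucket(self, prompt_len: int, max_new_tokens: int):
        needed = prompt_len + max_new_tokens
        for length in sorted(self.kv_caches):
            if length >= needed:  # smallest bucket that fits the whole sequence
                return self.kv_caches[length]
        return self.kv_caches[max(self.kv_caches)]  # fall back to the largest bucket
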
    def _prepare_inputs(self, chat_ml_sample: ChatMLSample, force_audio_gen: bool = False):
        input_tokens, _, audio_contents, _ = prepare_chatml_sample(
            chat_ml_sample,
            self.tokenizer,
        )

        postfix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
        if force_audio_gen:
            postfix += "<|audio_out_bos|>"
        postfix = self.tokenizer.encode(postfix, add_special_tokens=False)
        input_tokens.extend(postfix)

        # Configure the audio inputs
        audio_ids_l = []
        for audio_content in audio_contents:
            if audio_content.audio_url not in ["placeholder", ""]:
                raw_audio, _ = librosa.load(audio_content.audio_url, sr=self.audio_tokenizer.sampling_rate)
            elif audio_content.raw_audio is not None:
                raw_audio, _ = librosa.load(
                    BytesIO(base64.b64decode(audio_content.raw_audio)),
                    sr=self.audio_tokenizer.sampling_rate,
                )
            else:
                raw_audio = None

            if raw_audio is not None:
                audio_ids = self.audio_tokenizer.encode(raw_audio, self.audio_tokenizer.sampling_rate)
                audio_ids_l.append(audio_ids.squeeze(0).cpu())

        if len(audio_ids_l) > 0:
            audio_ids_start = torch.tensor(
                np.cumsum(np.array([0] + [audio_ids.shape[1] for audio_ids in audio_ids_l])),
                dtype=torch.long,
                device=self.device,
            )[0:-1]
            audio_ids_concat = torch.cat(audio_ids_l, dim=1)
        else:
            audio_ids_start = None
            audio_ids_concat = None

        sample = ChatMLDatasetSample(
            input_ids=torch.LongTensor(input_tokens),
            label_ids=None,
            audio_ids_concat=audio_ids_concat,
            audio_ids_start=audio_ids_start,
            audio_waveforms_concat=None,
            audio_waveforms_start=None,
            audio_sample_rate=None,
            audio_speaker_indices=None,
        )
        data = self.collator([sample])
        inputs = asdict(data)
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                inputs[k] = v.to(self.model.device)

        return inputs

    def _prepare_kv_caches(self):
        for kv_cache in self.kv_caches.values():
            kv_cache.reset()

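    # Illustrative only — how a caller might lay out a ChatMLSample so that
    # `_prepare_inputs` above picks up a reference voice, assuming the Message and
    # AudioContent dataclasses from higgs_audio.data_types are in scope
    # ("voice.wav" and the exact message layout are assumptions for illustration):
    @staticmethod
    def _build_voice_clone_sample(ref_transcript: str, tts_text: str) -> ChatMLSample:
        ref_audio = AudioContent(audio_url="voice.wav")  # or raw_audio=<base64-encoded bytes>
        return ChatMLSample(
            messages=[
                Message(role="user", content=ref_transcript),
                Message(role="assistant", content=[ref_audio]),
                Message(role="user", content=tts_text),
            ]
        )
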
    def generate(
        self,
        chat_ml_sample: ChatMLSample,
        max_new_tokens: int,
        temperature: float = 0.7,
        top_k: Optional[int] = None,
        top_p: float = 0.95,
        stop_strings: Optional[List[str]] = None,
        force_audio_gen: bool = False,
        ras_win_len: Optional[int] = None,
        ras_win_max_num_repeat: int = 2,
    ):
        """
        Generate audio from a ChatML sample.

        Args:
            chat_ml_sample: A ChatML sample.
            max_new_tokens: The maximum number of new tokens to generate.
            temperature: The sampling temperature; 0.0 selects greedy decoding.
            top_k: The top-k value to use for sampling.
            top_p: The top-p value to use for sampling.
            stop_strings: Strings that stop generation; defaults to ["<|end_of_text|>", "<|eot_id|>"].
            force_audio_gen: Whether to append "<|audio_out_bos|>" to the prompt to force audio output.
            ras_win_len: The window length for repetition-aware sampling (RAS).
            ras_win_max_num_repeat: The maximum number of repeats allowed within the RAS window.

        Returns:
            A HiggsAudioResponse containing the generated audio, its sampling rate,
            the generated text and tokens, and token-usage statistics.
        """
        # Default stop strings
        if stop_strings is None:
            stop_strings = ["<|end_of_text|>", "<|eot_id|>"]

        with torch.no_grad(), self.generate_lock:
            inputs = self._prepare_inputs(chat_ml_sample, force_audio_gen=force_audio_gen)
            prompt_token_ids = inputs["input_ids"][0].cpu().numpy()

            self._prepare_kv_caches()

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                stop_strings=stop_strings,
                tokenizer=self.tokenizer,
                do_sample=temperature != 0.0,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                past_key_values_buckets=self.kv_caches,
                ras_win_len=ras_win_len,
                ras_win_max_num_repeat=ras_win_max_num_repeat,
            )

            if len(outputs[1]) > 0:
                wv_list = []
                for output_audio in outputs[1]:
                    vq_code = revert_delay_pattern(output_audio).clip(0, self.audio_codebook_size - 1)[:, 1:-1]
                    wv_numpy = self.audio_tokenizer.decode(vq_code.unsqueeze(0))[0, 0]
                    wv_list.append(wv_numpy)
                wv_numpy = np.concatenate(wv_list)
                generated_audio_tokens = outputs[1][0].cpu().numpy()
            else:
                wv_numpy = None
                # Guard the no-audio case so the usage stats below cannot fail
                generated_audio_tokens = np.zeros((self.model.config.audio_num_codebooks, 0), dtype=np.int64)

            # We only support one request at a time now
            generated_text_tokens = outputs[0][0].cpu().numpy()[len(prompt_token_ids) :]
            generated_text = self.tokenizer.decode(generated_text_tokens)
            return HiggsAudioResponse(
                audio=wv_numpy,
                generated_audio_tokens=generated_audio_tokens,
                sampling_rate=self.audio_tokenizer.sampling_rate,
                generated_text=generated_text,
                generated_text_tokens=generated_text_tokens,
                usage={
                    "prompt_tokens": prompt_token_ids.shape[0],
                    "completion_tokens": generated_text_tokens.shape[0] + generated_audio_tokens.shape[1],
                    "total_tokens": (
                        prompt_token_ids.shape[0] + generated_text_tokens.shape[0] + generated_audio_tokens.shape[1]
                    ),
                    "cached_tokens": 0,
                },
            )

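    # Illustrative only — the shape of the transform that `revert_delay_pattern`
    # undoes in `generate` above, assuming the usual MusicGen-style layout where
    # codebook k is shifted right by k steps (this sketch is not the module's
    # implementation):
    @staticmethod
    def _revert_delay_pattern_sketch(codes: torch.Tensor) -> torch.Tensor:
        num_codebooks, seq_len = codes.shape
        out_len = seq_len - num_codebooks + 1
        # Row k contributed aligned tokens starting at step k
        return torch.stack([codes[k, k : k + out_len] for k in range(num_codebooks)])
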
    def text_normalize(self, text: str) -> str:
        """
        Normalize the text.
        """
        # Perform some basic normalization
        text = normalize_chinese_punctuation(text)
        # Handle parentheses
        text = text.replace("(", " ")
        text = text.replace(")", " ")
        return text
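A minimal end-to-end sketch of driving the serve engine above; the constructor keywords mirror `initialize_engine` in `higgs_audio_utils.py`, while the repo ids and the message layout are illustrative assumptions:

    engine = HiggsAudioServeEngine(
        model_name_or_path="bosonai/higgs-audio-v2-generation-3B-base",
        audio_tokenizer_name_or_path="bosonai/higgs-audio-v2-tokenizer",
        device="cuda",
    )
    response = engine.generate(
        chat_ml_sample=ChatMLSample(messages=[Message(role="user", content="Hello there!")]),
        max_new_tokens=512,
        temperature=0.7,
        force_audio_gen=True,
    )
    # response.audio is a waveform array at response.sampling_rate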
    	
higgs_audio/serve/utils.py
ADDED
@@ -0,0 +1,254 @@
import uuid
import base64
import re
import regex
from typing import AsyncGenerator, Union
import io
from pydub import AudioSegment
import torch
import numpy as np
from functools import lru_cache

from ..audio_processing.higgs_audio_tokenizer import HiggsAudioTokenizer


def random_uuid() -> str:
    return str(uuid.uuid4().hex)


async def async_generator_wrap(first_element, gen: AsyncGenerator):
    """Wrap an async generator with the first element."""
    yield first_element
    async for item in gen:
        yield item


@lru_cache(maxsize=50)
def encode_base64_content_from_file(file_path: str) -> str:
    """Encode the content of a local file to base64 format."""
    # Read the audio file as binary and encode it directly to Base64
    with open(file_path, "rb") as audio_file:
        audio_base64 = base64.b64encode(audio_file.read()).decode("utf-8")
    return audio_base64


def pcm16_to_target_format(
    np_audio: np.ndarray,
    sample_rate: int,
    bit_depth: int,
    channels: int,
    format: str,
    target_rate: int,
):
    wav_audio = AudioSegment(
        np_audio.tobytes(),
        frame_rate=sample_rate,
        sample_width=bit_depth // 8,
        channels=channels,
    )
    if target_rate is not None and target_rate != sample_rate:
        wav_audio = wav_audio.set_frame_rate(target_rate)

    # Export the PCM audio to the target container format (e.g., MP3)
    target_io = io.BytesIO()
    wav_audio.export(target_io, format=format)
    target_io.seek(0)

    return target_io

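# Illustrative only — a minimal usage sketch for pcm16_to_target_format, assuming
# 16-bit mono PCM at 24 kHz (the generated tone and the 16 kHz target rate are
# assumptions for illustration):
def _example_pcm16_to_mp3() -> bytes:
    tone = (np.sin(2 * np.pi * 440 * np.arange(24000) / 24000) * 32767).astype(np.int16)
    mp3_io = pcm16_to_target_format(
        tone, sample_rate=24000, bit_depth=16, channels=1, format="mp3", target_rate=16000
    )
    return mp3_io.read()
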
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")


def contains_chinese(text: str):
    return bool(chinese_char_pattern.search(text))


# Remove spaces between Chinese characters
def replace_blank(text: str):
    out_str = []
    for i, c in enumerate(text):
        if c == " ":
            # Keep the space only when it sits between two non-space ASCII characters
            if (
                0 < i < len(text) - 1
                and (text[i + 1].isascii() and text[i + 1] != " ")
                and (text[i - 1].isascii() and text[i - 1] != " ")
            ):
                out_str.append(c)
        else:
            out_str.append(c)
    return "".join(out_str)


def replace_corner_mark(text: str):
    text = text.replace("²", "平方")
    text = text.replace("³", "立方")
    return text


# Remove meaningless symbols
def remove_bracket(text: str):
    text = text.replace("(", "").replace(")", "")
    text = text.replace("【", "").replace("】", "")
    text = text.replace("`", "").replace("`", "")
    text = text.replace("——", " ")
    return text


# Paragraph-splitting logic:
# 1. Each chunk is at most token_max_n and at least token_min_n tokens long;
#    the last chunk is merged into the previous one if shorter than merge_len.
# 2. Sentence length is computed according to the language
#    (characters for Chinese, tokens otherwise).
# 3. Sentences are split at punctuation.
def split_paragraph(
    text: str,
    tokenize,
    lang="zh",
    token_max_n=80,
    token_min_n=60,
    merge_len=20,
    comma_split=False,
):
    def calc_utt_length(_text: str):
        if lang == "zh":
            return len(_text)
        else:
            return len(tokenize(_text))

    def should_merge(_text: str):
        if lang == "zh":
            return len(_text) < merge_len
        else:
            return len(tokenize(_text)) < merge_len

    if lang == "zh":
        pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
    else:
        pounc = [".", "?", "!", ";", ":"]
    if comma_split:
        pounc.extend([",", ","])

    if text[-1] not in pounc:
        if lang == "zh":
            text += "。"
        else:
            text += "."

    st = 0
    utts = []
    for i, c in enumerate(text):
        if c in pounc:
            if len(text[st:i]) > 0:
                utts.append(text[st:i] + c)
            if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
                tmp = utts.pop(-1)
                utts.append(tmp + text[i + 1])
                st = i + 2
            else:
                st = i + 1

    final_utts = []
    cur_utt = ""
    for utt in utts:
        if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
            final_utts.append(cur_utt)
            cur_utt = ""
        cur_utt = cur_utt + utt
    if len(cur_utt) > 0:
        if should_merge(cur_utt) and len(final_utts) != 0:
            final_utts[-1] = final_utts[-1] + cur_utt
        else:
            final_utts.append(cur_utt)

    return final_utts

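# Illustrative only — a minimal usage sketch for split_paragraph with a whitespace
# tokenizer for English text (the length thresholds below are assumptions):
def _example_split_paragraph() -> list:
    return split_paragraph(
        "First sentence. Second sentence! Third one? A trailing clause",
        tokenize=lambda s: s.split(),
        lang="en",
        token_max_n=8,
        token_min_n=4,
        merge_len=3,
    )
# Each returned chunk ends at a sentence boundary; a final "." is appended when the
# input lacks end punctuation, and a short leftover is merged into the previous chunk.
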
def is_only_punctuation(text: str):
    # Match strings that consist only of punctuation marks or are empty
    punctuation_pattern = r"^[\p{P}\p{S}]*$"
    return bool(regex.fullmatch(punctuation_pattern, text))


# Spell out Arabic numerals
def spell_out_number(text: str, inflect_parser):
    new_text = []
    st = None
    for i, c in enumerate(text):
        if not c.isdigit():
            if st is not None:
                num_str = inflect_parser.number_to_words(text[st:i])
                new_text.append(num_str)
                st = None
            new_text.append(c)
        else:
            if st is None:
                st = i
    if st is not None and st < len(text):
        num_str = inflect_parser.number_to_words(text[st:])
        new_text.append(num_str)
    return "".join(new_text)


def remove_emoji(text: str):
    # Pattern to match emojis and their modifiers:
    # - Standard emoji range
    # - Zero-width joiners (U+200D)
    # - Variation selectors (U+FE0F, U+FE0E)
    # - Skin tone modifiers (U+1F3FB to U+1F3FF)
    emoji_pattern = re.compile(
        r"["
        r"\U00010000-\U0010FFFF"  # Standard emoji range
        r"\u200D"  # Zero-width joiner
        r"\uFE0F\uFE0E"  # Variation selectors
        r"\U0001F3FB-\U0001F3FF"  # Skin tone modifiers
        r"]+",
        flags=re.UNICODE,
    )
    return emoji_pattern.sub(r"", text)


def remove_repeated_punctuations(text, punctuations):
    if len(punctuations) == 0:
        return text
    pattern = f"[{re.escape(''.join(punctuations))}]"  # Character class for the given punctuations
    return re.sub(rf"({pattern})\1+", r"\1", text)


def full_to_half_width(text: str) -> str:
    """Convert full-width punctuation to half-width in a given string."""
    full_width = "!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"
    half_width = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
    trans_table = str.maketrans(full_width, half_width)
    return text.translate(trans_table)


def split_interleaved_delayed_audios(
    audio_data: Union[list[list[int]], torch.Tensor],
    audio_tokenizer: HiggsAudioTokenizer,
    audio_stream_eos_id: int,
) -> Union[list[list[list[int]]], list[torch.Tensor]]:
    separator = [audio_stream_eos_id] * audio_tokenizer.num_codebooks

    # Convert the separator to a tensor if audio_data is a tensor
    if isinstance(audio_data, torch.Tensor):
        audio_data = audio_data.transpose(1, 0)
        separator = torch.tensor(separator)
        # Find the indices where the rows equal the separator
        split_indices = torch.where(torch.all(audio_data == separator, dim=1))[0]
        start = 0
        groups = []
        for idx in split_indices:
            groups.append(audio_data[start:idx].transpose(1, 0))
            start = idx + 1
        if start < len(audio_data):
            groups.append(audio_data[start:].transpose(1, 0))
    else:
        groups = []
        current = []
        for row in audio_data:
            current.append(row)

            if row == separator:
                groups.append(current)
                current = []

        # Don't forget the last group if there's no trailing separator
        if current:
            groups.append(current)

    return groups
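A short sketch of split_interleaved_delayed_audios separating two concatenated segments on the EOS separator; the two-codebook tokenizer stub and the token values are assumptions for illustration:

    import torch
    from types import SimpleNamespace

    fake_tokenizer = SimpleNamespace(num_codebooks=2)  # stands in for HiggsAudioTokenizer
    eos = 1024
    codes = torch.tensor([[1, 2, eos, 3, 4],
                          [5, 6, eos, 7, 8]])  # (num_codebooks, seq_len)
    groups = split_interleaved_delayed_audios(codes, fake_tokenizer, audio_stream_eos_id=eos)
    # groups -> [tensor([[1, 2], [5, 6]]), tensor([[3, 4], [7, 8]])]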
    	
higgs_audio_utils.py
ADDED
@@ -0,0 +1,290 @@
from typing import Optional

# Import HiggsAudio components
from higgs_audio.serve.serve_engine import HiggsAudioServeEngine
from higgs_audio.data_types import ChatMLSample, AudioContent, Message

import base64
from functools import lru_cache
from loguru import logger
import os
import json
import uuid
import time
import numpy as np
import re


def process_text_output(text_output: str):
    # Replace each run of consecutive <|AUDIO_OUT|> tokens with a single <|AUDIO_OUT|>
    text_output = re.sub(r"(<\|AUDIO_OUT\|>)+", r"<|AUDIO_OUT|>", text_output)
    return text_output


def check_return_audio(audio_wv: np.ndarray):
    # Check whether the returned audio is entirely silent
    if np.all(audio_wv == 0):
        logger.warning("Audio is silent, returning None")


def load_voice_presets():
    """Load the voice presets from the voice_examples directory."""
    try:
        with open(
            os.path.join(os.path.dirname(__file__), "examples", "audios", "config.json"),
            "r",
        ) as f:
            voice_dict = json.load(f)
        voice_presets = {k: v for k, v in voice_dict.items()}
        voice_presets["EMPTY"] = "No reference voice"
        logger.info(f"Loaded voice presets: {list(voice_presets.keys())}")
        return voice_presets
    except FileNotFoundError:
        logger.warning("Voice examples config file not found. Using empty voice presets.")
        return {"EMPTY": "No reference voice"}
    except Exception as e:
        logger.error(f"Error loading voice presets: {e}")
        return {"EMPTY": "No reference voice"}


SAMPLE_RATE = 24000
DEFAULT_STOP_STRINGS = ["<|end_of_text|>", "<|eot_id|>"]
VOICE_PRESETS = load_voice_presets()


def initialize_engine(model_path, audio_tokenizer_path) -> HiggsAudioServeEngine:
    engine = HiggsAudioServeEngine(
        model_name_or_path=model_path,
        audio_tokenizer_name_or_path=audio_tokenizer_path,
        device="cuda",
    )
    return engine

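# Illustrative only — a minimal usage sketch for initialize_engine (the Hugging Face
# repo ids below are assumptions, not read from this module):
def _example_initialize_engine() -> HiggsAudioServeEngine:
    return initialize_engine(
        model_path="bosonai/higgs-audio-v2-generation-3B-base",
        audio_tokenizer_path="bosonai/higgs-audio-v2-tokenizer",
    )
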
| 62 | 
         
            +
            def get_voice_preset(voice_preset):
         
     | 
| 63 | 
         
            +
                """Get the voice path and text for a given voice preset."""
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                preset_dir = os.path.join(os.path.dirname(__file__), "examples", "audios")
         
     | 
| 66 | 
         
            +
                voice_path = os.path.join(preset_dir, VOICE_PRESETS[voice_preset]["audio_file"])
         
     | 
| 67 | 
         
            +
                
         
     | 
| 68 | 
         
            +
                if not os.path.exists(voice_path):
         
     | 
| 69 | 
         
            +
                    logger.warning(f"Voice preset file not found: {voice_path}")
         
     | 
| 70 | 
         
            +
                    return None, "Voice preset not found"
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
                text = VOICE_PRESETS[voice_preset]["transcript"]
         
     | 
| 73 | 
         
            +
                return voice_path, text
         
     | 
| 74 | 
         
            +
             
     | 
| 75 | 
         
            +
             
+def normalize_chinese_punctuation(text):
+    """
+    Convert Chinese (full-width) punctuation marks to English (half-width) equivalents.
+    """
+    # Mapping of Chinese punctuation to English punctuation
+    chinese_to_english_punct = {
+        ",": ", ",  # comma
+        "。": ".",  # period
+        ":": ":",  # colon
+        ";": ";",  # semicolon
+        "?": "?",  # question mark
+        "!": "!",  # exclamation mark
+        "(": "(",  # left parenthesis
+        ")": ")",  # right parenthesis
+        "【": "[",  # left square bracket
+        "】": "]",  # right square bracket
+        "《": "<",  # left angle quote
+        "》": ">",  # right angle quote
+        "“": '"',  # left double quotation
+        "”": '"',  # right double quotation
+        "‘": "'",  # left single quotation
+        "’": "'",  # right single quotation
+        "、": ",",  # enumeration comma
+        "—": "-",  # em dash
+        "…": "...",  # ellipsis
+        "·": ".",  # middle dot
+        "「": '"',  # left corner bracket
+        "」": '"',  # right corner bracket
+        "『": '"',  # left double corner bracket
+        "』": '"',  # right double corner bracket
+    }
+
+    # Replace each Chinese punctuation mark with its English counterpart
+    for zh_punct, en_punct in chinese_to_english_punct.items():
+        text = text.replace(zh_punct, en_punct)
+
+    return text
+
+
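A quick sanity check of the mapping above, runnable as-is:

    # Full-width comma and period become half-width (the comma also gains a space).
    assert normalize_chinese_punctuation("你好,世界。") == "你好, 世界."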
+def normalize_text(transcript: str):
+    transcript = normalize_chinese_punctuation(transcript)
+    # Other normalizations (e.g., parentheses and other symbols; to be improved in the future)
+    transcript = transcript.replace("(", " ")
+    transcript = transcript.replace(")", " ")
+    transcript = transcript.replace("°F", " degrees Fahrenheit")
+    transcript = transcript.replace("°C", " degrees Celsius")
+
+    for tag, replacement in [
+        ("[laugh]", "<SE>[Laughter]</SE>"),
+        ("[humming start]", "<SE>[Humming]</SE>"),
+        ("[humming end]", "<SE_e>[Humming]</SE_e>"),
+        ("[music start]", "<SE_s>[Music]</SE_s>"),
+        ("[music end]", "<SE_e>[Music]</SE_e>"),
+        ("[music]", "<SE>[Music]</SE>"),
+        ("[sing start]", "<SE_s>[Singing]</SE_s>"),
+        ("[sing end]", "<SE_e>[Singing]</SE_e>"),
+        ("[applause]", "<SE>[Applause]</SE>"),
+        ("[cheering]", "<SE>[Cheering]</SE>"),
+        ("[cough]", "<SE>[Cough]</SE>"),
+    ]:
+        transcript = transcript.replace(tag, replacement)
+
+    lines = transcript.split("\n")
+    transcript = "\n".join([" ".join(line.split()) for line in lines if line.strip()])
+    transcript = transcript.strip()
+
+    if not any([transcript.endswith(c) for c in [".", "!", "?", ",", ";", '"', "'", "</SE_e>", "</SE>"]]):
+        transcript += "."
+
+    return transcript
+
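For example, event tags are rewritten into the model's <SE> markup and a terminal period is appended when the text ends without one:

    print(normalize_text("Welcome [laugh] to the show"))
    # -> Welcome <SE>[Laughter]</SE> to the show.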
+@lru_cache(maxsize=20)
+def encode_audio_file(file_path):
+    """Encode an audio file to base64."""
+    with open(file_path, "rb") as audio_file:
+        return base64.b64encode(audio_file.read()).decode("utf-8")
+
+
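Note that lru_cache keys on the file path string, so a reference clip re-recorded at the same path would keep serving the stale encoding until the process restarts (or encode_audio_file.cache_clear() is called).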
+def prepare_chatml_sample(
+    voice_preset: str,
+    text: str,
+    reference_audio: Optional[str] = None,
+    reference_text: Optional[str] = None,
+    system_prompt: str = "",
+):
+    """Prepare a ChatMLSample for the HiggsAudioServeEngine."""
+    messages = []
+
+    # Add system message if provided
+    if len(system_prompt) > 0:
+        messages.append(Message(role="system", content=system_prompt))
+
+    # Add reference audio if provided
+    audio_base64 = None
+    ref_text = ""
+
+    if reference_audio:
+        # Custom reference audio
+        audio_base64 = encode_audio_file(reference_audio)
+        ref_text = reference_text or ""
+    elif voice_preset != "EMPTY":
+        # Voice preset
+        voice_path, ref_text = get_voice_preset(voice_preset)
+        if voice_path is None:
+            logger.warning(f"Voice preset {voice_preset} not found, skipping reference audio")
+        else:
+            audio_base64 = encode_audio_file(voice_path)
+
+    # Only add reference audio if we have it
+    if audio_base64 is not None:
+        # Add user message with reference text
+        messages.append(Message(role="user", content=ref_text))
+
+        # Add assistant message with audio content
+        audio_content = AudioContent(raw_audio=audio_base64, audio_url="")
+        messages.append(Message(role="assistant", content=[audio_content]))
+
+    # Add the main user message
+    text = normalize_text(text)
+    messages.append(Message(role="user", content=text))
+
+    return ChatMLSample(messages=messages)
+
+
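For a preset or cloned voice, the resulting sample follows the usual in-context voice-cloning pattern: reference transcript as a user turn, reference audio as an assistant turn, then the target text. Sketched with illustrative values:

    # [Message(role="system", content=system_prompt),        # only if non-empty
    #  Message(role="user", content="<reference transcript>"),
    #  Message(role="assistant", content=[AudioContent(raw_audio="<base64>", audio_url="")]),
    #  Message(role="user", content="<normalized target text>")]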
+def text_to_speech(
+    engine,
+    text,
+    system_prompt="",
+    voice_preset="EMPTY",
+    reference_audio=None,
+    reference_text=None,
+    max_completion_tokens=1024,
+    temperature=1.0,
+    top_p=0.95,
+    top_k=50,
+    stop_strings=None,
+    ras_win_len=7,
+    ras_win_max_num_repeat=2,
+):
+    """
+    Convert text to speech using HiggsAudioServeEngine.
+
+    Args:
+        engine: An initialized HiggsAudioServeEngine
+        text: The text to convert to speech
+        system_prompt: System prompt to guide the model
+        voice_preset: The voice preset to use (or "EMPTY" for no preset)
+        reference_audio: Optional path to a reference audio file
+        reference_text: Optional transcript of the reference audio
+        max_completion_tokens: Maximum number of tokens to generate
+        temperature: Sampling temperature for generation
+        top_p: Top-p sampling parameter
+        top_k: Top-k sampling parameter
+        stop_strings: Dataframe containing stop strings
+        ras_win_len: Window length for repetition-avoidance sampling
+        ras_win_max_num_repeat: Maximum number of repetitions allowed in the window
+
+    Returns:
+        Tuple of (generated_text, (sample_rate, audio_data)) where audio_data is an int16 numpy array
+    """
+    logger.debug(
+        f"text={text}, voice_preset={voice_preset}, "
+        f"reference_audio={reference_audio}, reference_text={reference_text}"
+    )
+
+    try:
+        # Prepare ChatML sample
+        chatml_sample = prepare_chatml_sample(voice_preset, text, reference_audio, reference_text, system_prompt)
+
+        # Convert stop strings format
+        if stop_strings is None:
+            stop_list = DEFAULT_STOP_STRINGS
+        else:
+            stop_list = [s for s in stop_strings["stops"] if s.strip()]
+
+        request_id = f"tts-playground-{str(uuid.uuid4())}"
+        logger.info(
+            f"{request_id}: Generating speech for text: {text[:100]}..., \n"
+            f"with parameters: temperature={temperature}, top_p={top_p}, top_k={top_k}, stop_list={stop_list}, "
+            f"ras_win_len={ras_win_len}, ras_win_max_num_repeat={ras_win_max_num_repeat}"
+        )
+        start_time = time.time()
+
+        # Generate using the engine
+        response = engine.generate(
+            chat_ml_sample=chatml_sample,
+            max_new_tokens=max_completion_tokens,
+            temperature=temperature,
+            top_k=top_k if top_k > 0 else None,
+            top_p=top_p,
+            stop_strings=stop_list,
+            ras_win_len=ras_win_len if ras_win_len > 0 else None,
+            ras_win_max_num_repeat=max(ras_win_len, ras_win_max_num_repeat),
+        )
+
+        generation_time = time.time() - start_time
+        logger.info(f"{request_id}: Generated audio in {generation_time:.3f} seconds")
+
+        # Process the response
+        text_output = process_text_output(response.generated_text)
+
+        if response.audio is not None:
+            # Convert float audio in [-1, 1] to int16 for Gradio
+            audio_data = (response.audio * 32767).astype(np.int16)
+            check_return_audio(audio_data)
+            return text_output, (response.sampling_rate, audio_data)
+        else:
+            logger.warning("No audio generated")
+            return text_output, None
+
+    except Exception as e:
+        error_msg = f"Error generating speech: {e}"
+        logger.error(error_msg)
+        return f"❌ {error_msg}", None
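Putting the pieces together, a minimal end-to-end sketch (paths are assumptions; no preset or reference audio is used):

    engine = initialize_engine(model_path, audio_tokenizer_path)
    text_out, audio = text_to_speech(engine, "Hello from Higgs Audio.", voice_preset="EMPTY")
    if audio is not None:
        sample_rate, pcm16 = audio  # int16 samples, ready for a Gradio Audio component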
requirements.txt
CHANGED

@@ -1,7 +1,7 @@
 tqdm
 librosa==0.10.2.post1
 peft==0.15.1
-transformers
+transformers>=4.45.1,<4.47.0
 scipy==1.14.0
 numpy==1.26.4
 xfuser==0.4.1
@@ -9,4 +9,19 @@ ftfy
 einops
 omegaconf
 torchvision
-ninja
+ninja
+gradio_extendedaudio @ https://github.com/OutofAi/gradio-extendedaudio/releases/download/0.0.1/gradio_extendedaudio-0.0.1-py3-none-any.whl
+
+dacite
+boto3==1.35.36
+s3fs
+json_repair
+pandas
+pydantic
+vector_quantize_pytorch
+loguru
+pydub
+ruff==0.12.2
+click
+torchaudio
+descript-audio-codec
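The notable changes here: transformers moves from unpinned to an explicit >=4.45.1,<4.47.0 window, presumably the range the bundled higgs_audio modeling code was written against, and the audio widget comes from a prebuilt gradio_extendedaudio wheel rather than PyPI.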