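# NOTE: the lines below are page residue from the hosting site, kept as
# comments so that this module parses as Python:
# Den Pavloff | time report | 52c0d1f | raw | history blame | 10.9 kB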
import torch
import librosa
import requests
import time
from nemo.collections.tts.models import AudioCodecModel
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
@dataclass
class Config:
    """Default settings for the Kani TTS pipeline.

    Bundles the Hugging Face model identifiers, the base of the special-token
    layout and the sampling parameters read by ``NemoAudioPlayer`` and
    ``KaniModel``.
    """

    # Causal LM that generates the text + audio-codec token stream.
    model_name: str = "nineninesix/lfm-nano-codec-tts-exp-4-large-61468-st"
    # NeMo audio codec used to decode the generated codes into audio.
    audiocodec_name: str = "nvidia/nemo-nano-codec-22khz-0.6kbps-12.5fps"
    # Forwarded as ``device_map`` to ``AutoModelForCausalLM.from_pretrained``.
    device_map: str = "auto"

    # Text-vocabulary size; speech/control special-token ids are derived
    # from this value by NemoAudioPlayer.
    tokeniser_length: int = 64_400
    start_of_text: int = 1
    end_of_text: int = 2

    # Sampling settings forwarded to ``model.generate``.
    max_new_tokens: int = 2_000
    temperature: float = 0.6
    top_p: float = 0.95
    repetition_penalty: float = 1.1
class NemoAudioPlayer:
    """Turn Kani LM output ids into audio using a NeMo audio codec.

    Audio is carried in the LM output as flat token ids between
    ``start_of_speech`` and ``end_of_speech``. The ids are interleaved frame by
    frame over ``_NUM_CODEBOOKS`` codebook layers, and layer ``i`` uses ids in
    ``[audio_tokens_start + i * codebook_size,
    audio_tokens_start + (i + 1) * codebook_size)``.
    """

    # Codebook layers per audio frame in the flat token stream.
    _NUM_CODEBOOKS = 4

    def __init__(self, config, text_tokenizer_name: str = None) -> None:
        """Load the codec (and optionally a text tokenizer) and derive token ids.

        Args:
            config: settings object (e.g. ``Config``) exposing
                ``audiocodec_name``, ``tokeniser_length``, ``start_of_text``
                and ``end_of_text``.
            text_tokenizer_name: optional Hugging Face tokenizer name. When
                given, ``get_waveform`` also decodes the text span.
        """
        self.conf = config
        print(f"Loading NeMo codec model: {self.conf.audiocodec_name}")
        # Load the NeMo codec model and switch it to eval mode.
        self.nemo_codec_model = AudioCodecModel.from_pretrained(
            self.conf.audiocodec_name
        ).eval()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Moving NeMo codec to device: {self.device}")
        self.nemo_codec_model.to(self.device)

        self.text_tokenizer_name = text_tokenizer_name
        # Always define the attribute so get_text can report a missing
        # tokenizer explicitly instead of hitting an AttributeError.
        self.tokenizer = None
        if self.text_tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(self.text_tokenizer_name)

        # Token configuration: control tokens sit just above the text vocabulary.
        self.tokeniser_length = self.conf.tokeniser_length
        self.start_of_text = self.conf.start_of_text
        self.end_of_text = self.conf.end_of_text
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        # Width of each codebook layer's id range (see class docstring).
        self.codebook_size = 4032

    @staticmethod
    def _first_index(ids, token):
        """Return the position of the first ``token`` in 1-D ``ids``, or None."""
        positions = (ids == token).nonzero(as_tuple=True)[0]
        if positions.numel() == 0:
            return None
        return positions[0].item()

    def output_validation(self, out_ids):
        """Validate that output contains required speech tokens.

        Raises:
            ValueError: if ``start_of_speech`` or ``end_of_speech`` is absent.
        """
        start_of_speech_flag = self.start_of_speech in out_ids
        end_of_speech_flag = self.end_of_speech in out_ids
        if not (start_of_speech_flag and end_of_speech_flag):
            raise ValueError('Special speech tokens not found in output!')

    def get_nano_codes(self, out_ids):
        """Extract per-layer codec indices from a flat 1-D id tensor.

        The first ``start_of_speech`` and the first ``end_of_speech`` delimit
        the audio span.

        Args:
            out_ids: 1-D integer tensor of generated ids.

        Returns:
            (audio_codes, len_): codes of shape (1, 4, n_frames) and a
            1-element tensor holding ``n_frames``.

        Raises:
            ValueError: if the speech markers are missing or out of order, the
                span is empty or not a multiple of 4 long, or a token falls
                outside its layer's codebook range.
        """
        # `.item()` on an empty or multi-element tensor raises RuntimeError
        # (not IndexError), so presence is checked explicitly here.
        start_a_idx = self._first_index(out_ids, self.start_of_speech)
        end_a_idx = self._first_index(out_ids, self.end_of_speech)
        if start_a_idx is None or end_a_idx is None:
            raise ValueError('Speech start/end tokens not found!')
        if start_a_idx >= end_a_idx:
            raise ValueError('Invalid audio codes sequence!')

        audio_codes = out_ids[start_a_idx + 1: end_a_idx]
        if audio_codes.numel() == 0:
            raise ValueError('Audio codes sequence is empty!')
        if len(audio_codes) % self._NUM_CODEBOOKS:
            raise ValueError('Audio codes length must be multiple of 4!')
        audio_codes = audio_codes.reshape(-1, self._NUM_CODEBOOKS)

        # Undo the per-layer offset and the shift past the special tokens.
        layer_offsets = torch.arange(
            self._NUM_CODEBOOKS, dtype=audio_codes.dtype, device=audio_codes.device
        ) * self.codebook_size
        audio_codes = audio_codes - layer_offsets - self.audio_tokens_start

        # A valid code lies in [0, codebook_size). Anything else is a token from
        # a different layer (or not an audio token at all).
        if ((audio_codes < 0) | (audio_codes >= self.codebook_size)).any():
            raise ValueError('Invalid audio tokens detected!')

        # (n_frames, 4) -> (1, 4, n_frames) batch layout for the codec.
        audio_codes = audio_codes.T.unsqueeze(0)
        len_ = torch.tensor([audio_codes.shape[-1]])
        return audio_codes, len_

    def get_text(self, out_ids):
        """Decode the span from the first ``start_of_text`` to the first ``end_of_text``.

        Raises:
            ValueError: if either text marker is missing or no text tokenizer
                was configured.
        """
        start_t_idx = self._first_index(out_ids, self.start_of_text)
        end_t_idx = self._first_index(out_ids, self.end_of_text)
        if start_t_idx is None or end_t_idx is None:
            raise ValueError('Text start/end tokens not found!')
        if self.tokenizer is None:
            raise ValueError('Text tokenizer is not configured!')
        txt_tokens = out_ids[start_t_idx: end_t_idx + 1]
        return self.tokenizer.decode(txt_tokens, skip_special_tokens=True)

    def get_waveform(self, out_ids):
        """Convert model output ids to an audio waveform.

        Args:
            out_ids: generated ids of any shape; flattened before use.

        Returns:
            (audio, text): ``audio`` is the codec output as a squeezed numpy
            array. ``text`` is the decoded text span when a text tokenizer
            name was given, otherwise None.
        """
        out_ids = out_ids.flatten()
        self.output_validation(out_ids)

        audio_codes, len_ = self.get_nano_codes(out_ids)
        audio_codes, len_ = audio_codes.to(self.device), len_.to(self.device)
        with torch.inference_mode():
            reconstructed_audio, _ = self.nemo_codec_model.decode(
                tokens=audio_codes,
                tokens_len=len_
            )
        output_audio = reconstructed_audio.cpu().detach().numpy().squeeze()

        text = self.get_text(out_ids) if self.text_tokenizer_name else None
        return output_audio, text
class KaniModel:
    """Kani TTS pipeline: text prompt -> causal-LM generation -> waveform.

    The LM is loaded from ``config.model_name``. Generated ids are passed to
    ``player.get_waveform`` for codec decoding.
    """

    def __init__(self, config, player: "NemoAudioPlayer", token: str) -> None:
        """Load the causal LM and its tokenizer.

        Args:
            config: settings object (e.g. ``Config``) providing ``model_name``,
                ``device_map`` and the sampling parameters.
            player: decoder providing the special-token ids and
                ``get_waveform``.
            token: Hugging Face access token forwarded to ``from_pretrained``.
        """
        self.conf = config
        self.player = player
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Loading model: {self.conf.model_name}")
        print(f"Target device: {self.device}")
        # NOTE: trust_remote_code=True executes code shipped in the model
        # repo; only point model_name at repositories you trust.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.conf.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.conf.device_map,
            token=token,
            trust_remote_code=True  # May be needed for some models
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.conf.model_name,
            token=token,
            trust_remote_code=True
        )
        print(f"Model loaded successfully on device: {next(self.model.parameters()).device}")

    def get_input_ids(self, text_prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
        """Build the prompt ``[START_OF_HUMAN] + text ids + [END_OF_TEXT, END_OF_HUMAN]``.

        Returns:
            (input_ids, attention_mask), each of shape (1, seq_len). The mask
            is all ones (int64).
        """
        START_OF_HUMAN = self.player.start_of_human
        END_OF_TEXT = self.player.end_of_text
        END_OF_HUMAN = self.player.end_of_human

        input_ids = self.tokenizer(text_prompt, return_tensors="pt").input_ids
        start_token = torch.tensor([[START_OF_HUMAN]], dtype=torch.int64)
        end_tokens = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=torch.int64)

        modified_input_ids = torch.cat([start_token, input_ids, end_tokens], dim=1)
        attention_mask = torch.ones(1, modified_input_ids.shape[1], dtype=torch.int64)
        return modified_input_ids, attention_mask

    def model_request(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Sample a continuation of the prompt and return all ids on the CPU.

        Generation stops at ``player.end_of_speech``. The returned tensor is
        whatever ``model.generate`` returns, moved to the CPU.
        """
        # With device_map="auto" the weights may not be on self.device, so
        # send the inputs to the device the model itself reports.
        target_device = getattr(self.model, "device", self.device)
        input_ids = input_ids.to(target_device)
        attention_mask = attention_mask.to(target_device)

        # A pad id of 0 is valid, so compare against None rather than
        # relying on truthiness.
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = self.tokenizer.eos_token_id

        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=self.conf.max_new_tokens,
                do_sample=True,
                temperature=self.conf.temperature,
                top_p=self.conf.top_p,
                repetition_penalty=self.conf.repetition_penalty,
                num_return_sequences=1,
                eos_token_id=self.player.end_of_speech,
                pad_token_id=pad_token_id
            )
        return generated_ids.to('cpu')

    def time_report(self, point_1, point_2, point_3):
        """Format generation, codec and total durations (seconds, 2 decimals)."""
        model_request = point_2 - point_1
        player_time = point_3 - point_2
        total_time = point_3 - point_1
        report = f"MODEL GENERATION: {model_request:.2f}\nNANO CODEC: {player_time:.2f}\nTOTAL: {total_time:.2f}"
        return report

    def run_model(self, text: str):
        """Complete pipeline: text -> tokens -> generation -> audio.

        Returns:
            (audio, text, report): the waveform from the player, the input
            text unchanged, and the ``time_report`` string.
        """
        input_ids, attention_mask = self.get_input_ids(text)
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with wall-clock adjustments and is meant for timestamps, not durations.
        point_1 = time.perf_counter()
        model_output = self.model_request(input_ids, attention_mask)
        point_2 = time.perf_counter()
        audio, _ = self.player.get_waveform(model_output)
        point_3 = time.perf_counter()
        return audio, text, self.time_report(point_1, point_2, point_3)
class Demo:
    """Fetch and load the reference audio clips for the demo sentences.

    Calling an instance returns ``{sentence: waveform}``, with each waveform
    loaded by librosa at 22050 Hz.
    """

    def __init__(self):
        self.audio_dir = './audio_examples'
        os.makedirs(self.audio_dir, exist_ok=True)
        # Paired positionally with self.urls in __call__.
        self.sentences = [
            "You make my days brighter, and my wildest dreams feel like reality. How do you do that?",
            "Anyway, um, so, um, tell me, tell me all about her. I mean, what's she like? Is she really, you know, pretty?",
            "Great, and just a couple quick questions so we can match you with the right buyer. Is your home address still 330 East Charleston Road?",
            "No, that does not make you a failure. No, sweetie, no. It just, uh, it just means that you're having a tough time...",
            "Oh, yeah. I mean did you want to get a quick snack together or maybe something before you go?",
            "I-- Oh, I am such an idiot sometimes. I'm so sorry. Um, I-I don't know where my head's at.",
            "Got it. $300,000. I can definitely help you get a very good price for your property by selecting a realtor.",
            "Holy fu- Oh my God! Don't you understand how dangerous it is, huh?"
        ]
        # The remote numbering differs from the list order. Each clip is
        # cached locally as "<position>.wav" by __call__.
        self.urls = [
            'https://www.nineninesix.ai/examples/kani/1.wav',
            'https://www.nineninesix.ai/examples/kani/2.wav',
            'https://www.nineninesix.ai/examples/kani/5.wav',
            'https://www.nineninesix.ai/examples/kani/6.wav',
            'https://www.nineninesix.ai/examples/kani/3.wav',
            'https://www.nineninesix.ai/examples/kani/7.wav',
            'https://www.nineninesix.ai/examples/kani/4.wav',
            'https://www.nineninesix.ai/examples/kani/8.wav'
        ]

    def download_audio(self, url: str, filename: str, *, timeout: float = 30.0):
        """Download ``url`` to ``audio_dir/filename`` unless it is already cached.

        Args:
            url: remote file to fetch.
            filename: local file name inside ``self.audio_dir``.
            timeout: seconds passed to ``requests.get`` so that a stalled
                server cannot hang the demo indefinitely.

        Returns:
            Path of the local file.

        Raises:
            requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        """
        filepath = os.path.join(self.audio_dir, filename)
        if os.path.exists(filepath):
            return filepath
        r = requests.get(url, timeout=timeout)
        r.raise_for_status()
        # Write to a side file, then rename, so an interrupted write never
        # leaves a truncated file that the existence check above would treat
        # as a valid cache entry.
        partial = filepath + '.part'
        try:
            with open(partial, 'wb') as f:
                f.write(r.content)
            os.replace(partial, filepath)
        finally:
            if os.path.exists(partial):
                os.remove(partial)
        return filepath

    def get_audio(self, filepath: str):
        """Load ``filepath`` with librosa, resampled to 22050 Hz; returns the samples."""
        arr, _ = librosa.load(filepath, sr=22050)
        return arr

    def __call__(self):
        """Download (if needed) and load every clip; returns ``{sentence: samples}``."""
        examples = {}
        for idx, (sentence, url) in enumerate(zip(self.sentences, self.urls), start=1):
            filename = f"{idx}.wav"
            filepath = self.download_audio(url, filename)
            examples[sentence] = self.get_audio(filepath)
        return examples