---
gated: true
extra_gated_prompt: >
  **Terms of Academic, Non-Commercial Use**


  I ACKNOWLEDGE THAT I HAVE READ [THE LICENSE AGREEMENT](./LICENSE), UNDERSTAND
  IT AND AGREE TO BE BOUND BY ITS TERMS AND CONDITIONS.


  It is important that you read the entirety of and understand the License
  Agreement. There are, however, a few key points that we need to emphasize
  again:


  - **ACADEMIC, NON-COMMERCIAL USE**: The license granted is for academic,
  non-commercial purposes only. The term "academic, non-commercial" means
  academic or other scholarly research which (a) is not undertaken for any
  direct or indirect for-profit purposes, and (b) is not intended to produce
  works, services, or data for commercial use.  

  - **INTERNAL USE**: The license granted is for your own internal use only. You
  are not allowed to sublicense, distribute, transfer, disclose or make
  available the software to any third party.  

  - **NO WARRANTY**: The Software is provided "as is" and any express or implied
  warranties are disclaimed.  
extra_gated_button_content: Agree and access
pipeline_tag: audio-to-audio
---

# CAST 0.7B — Speech-to-Speech

[arXiv](https://arxiv.org/abs/2509.26276) · [Demo](https://mortezaro.github.io/speech-cast/) · [Codec dependency](https://huggingface.co/KrauthammerLab/cast-wavtokenizer-24k-40tps)

This repository contains the final checkpoint files. Encoding and decoding depend on KrauthammerLab/cast-wavtokenizer-24k-40tps.

## CAST 0.7B — Speech-to-Speech Model

CAST 0.7B is a speech-to-speech language model built on a 0.7B parameter Gemma3-style LM.
It generates natural continuations of spoken audio.

It requires the companion CAST WavTokenizer (KrauthammerLab/cast-wavtokenizer-24k-40tps) to encode input audio into discrete tokens and to decode generated tokens back to waveforms.
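This repository is gated: accept the license terms, then authenticate with the Hugging Face Hub before downloading. A minimal sketch, assuming you are logged in via `huggingface-cli login` or have `HF_TOKEN` set:

```python
from huggingface_hub import snapshot_download

# Fetch both repositories locally (gated repos require accepted terms and an authenticated session)
lm_dir = snapshot_download("KrauthammerLab/cast-0.7b-s2s")
wt_dir = snapshot_download("KrauthammerLab/cast-wavtokenizer-24k-40tps")
print("LM files in:", lm_dir)
print("Codec files in:", wt_dir)
```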


## Demo

Interactive samples and usage examples: https://mortezaro.github.io/speech-cast/


## Paper

*Optimizing Speech Language Models for Acoustic Consistency*  
arXiv: [2509.26276](https://arxiv.org/abs/2509.26276)

We study speech language models that incorporate semantic initialization and planning losses to achieve robust and consistent generation. Our approach initializes speech tokens with self-supervised features, applies a light alignment loss, and trains with thinning and auxiliary objectives that target robustness and content planning. We train three models: a 0.7B speech-only model, a 1.0B speech-only model, and a 1.0B interleaved model with both text and speech. Acoustic studies show that the speech-only models achieve the highest consistency across speaker, gender, sentiment, room, and background factors, surpassing larger systems. Interleaving improves lexical and syntactic probes and semantic–acoustic alignment but reduces consistency. Linear probes show that our initialization biases the model toward content structure while trading off prosody detail. These results show that LM-side design and training mix control the balance between acoustic stability and semantic grounding without changes to the tokenizer or runtime architecture. A demo and model weights are available for exploration.

## Installation

```bash
pip install torch torchaudio transformers accelerate soundfile
pip install git+https://github.com/jishengpeng/WavTokenizer.git
```
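A quick import check can confirm the environment; the `WavTokenizer` import path below mirrors the layout of the WavTokenizer repository and may need adjusting depending on how the package was installed:

```python
import torch
import torchaudio
import transformers
from huggingface_hub import hf_hub_download
from decoder.pretrained import WavTokenizer  # adjust if your install exposes a different module path

print("torch", torch.__version__, "| torchaudio", torchaudio.__version__,
      "| transformers", transformers.__version__,
      "| CUDA available:", torch.cuda.is_available())
```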

## 1. Resynthesis

```python
import torch
import torchaudio
import soundfile as sf
from huggingface_hub import hf_hub_download
from decoder.pretrained import WavTokenizer  # import path as in the WavTokenizer repo; adjust to your install

WT_REPO = "KrauthammerLab/cast-wavtokenizer-24k-40tps"

# Download tokenizer ckpt + config from HF
wt_ckpt = hf_hub_download(WT_REPO, filename="wavtokenizer_large_unify_600_24k.ckpt")
try:
    wt_cfg  = hf_hub_download(WT_REPO, filename="config.yaml")
except Exception:
    wt_cfg = None  # the config file is optional

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load WavTokenizer (codec)
wt = WavTokenizer.from_pretrained0802(wt_cfg, wt_ckpt).to(device)

# Load a 16 kHz prompt
wav16, sr = torchaudio.load("prompt_16k.wav")  # mono recommended
assert sr == 16000, f"Expected 16k input, got {sr}"

# (Optional) ensure mono
if wav16.size(0) > 1:
    wav16 = wav16.mean(dim=0, keepdim=True)

# Resample 16 kHz -> 24 kHz before encoding (the codec operates at 24 kHz)
wav24 = torchaudio.functional.resample(wav16, orig_freq=16000, new_freq=24000).to(device)

# Encode → features, codes
bandwidth_id = torch.tensor([0], device=device)
feats, codes = wt.encode_infer(wav24, bandwidth_id=bandwidth_id)  # feats: [1, ?, T], codes: [1, streams?, T] or [1, T]

# Decode back to waveform (24 kHz)
recon24 = wt.decode(feats, bandwidth_id=bandwidth_id)  # [1, T] or [1,1,T]
if recon24.dim() == 3:
    recon24 = recon24.squeeze(0)

# Save 24k round-trip audio
sf.write("recon_24k.wav", recon24.squeeze(0).detach().cpu().numpy(), 24000)
print("Wrote recon_24k.wav")
## 2. Speech generation


```python
import math
from typing import List, Set

import torch
import torchaudio
import soundfile as sf
from huggingface_hub import hf_hub_download
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)
from decoder.pretrained import WavTokenizer  # import path as in the WavTokenizer repo; adjust to your install

LM_REPO = "KrauthammerLab/cast-0.7b-s2s"
WT_REPO = "KrauthammerLab/cast-wavtokenizer-24k-40tps"

device = "cuda" if torch.cuda.is_available() else "cpu"
codes_per_second = 40          # the codec emits ~40 tokens per second
codebook_size    = 4096        # speech vocabulary [Sp1]..[Sp4096]
speech_prefix    = "[Speech]"  # prefix marking the start of a speech sequence

# ---------- helpers ----------
def equal_power_crossfade(prev_24k: torch.Tensor, cont_24k: torch.Tensor, fade_ms: int = 40, sr: int = 24000) -> torch.Tensor:
    """Equal-power crossfade between prev and cont (both [1,T] @ 24k)."""
    fade = max(1, int(sr * fade_ms / 1000))
    prev_24k = prev_24k.to(device)
    cont_24k = cont_24k.to(device)
    if prev_24k.size(1) < fade or cont_24k.size(1) < fade:
        return torch.cat([prev_24k, cont_24k], dim=1)
    a = prev_24k[:, -fade:]
    b = cont_24k[:, :fade]
    t = torch.linspace(0, 1, fade, device=device).view(1, -1)
    mix = torch.cos(t * 0.5 * math.pi) * a + torch.sin(t * 0.5 * math.pi) * b
    return torch.cat([prev_24k[:, :-fade], mix, cont_24k[:, fade:]], dim=1)

class SpeechOnlyLogitsProcessor(LogitsProcessor):
    """Mask logits so only [Sp#] tokens (and EOS) can be sampled."""
    def __init__(self, allowed: Set[int]):
        super().__init__()
        self.allowed = list(allowed)

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
        mask = torch.full_like(scores, float("-inf"))
        mask[..., self.allowed] = 0.0
        return scores + mask

# ---------- load codec ----------
wt_ckpt = hf_hub_download(WT_REPO, filename="wavtokenizer_large_unify_600_24k.ckpt")
try:
    wt_cfg  = hf_hub_download(WT_REPO, filename="config.yaml")
except Exception:
    wt_cfg = None

wt = WavTokenizer.from_pretrained0802(wt_cfg, wt_ckpt).to(device)

# ---------- load LM + tokenizer ----------
tok = AutoTokenizer.from_pretrained(LM_REPO)
lm  = AutoModelForCausalLM.from_pretrained(LM_REPO, torch_dtype=torch.bfloat16).to(device).eval()

# Build speech token id table: "[Sp1]".."[Sp4096]" must be single tokens
speech_token_ids: List[int] = []
for i in range(1, codebook_size + 1):
    ids = tok(f"[Sp{i}]", add_special_tokens=False)["input_ids"]
    if len(ids) != 1:
        raise RuntimeError(f"[Sp{i}] is not a single token; tokenizer mismatch.")
    speech_token_ids.append(ids[0])

# For mapping back: token_id -> code_index (0-based)
id2code = {tid: i for i, tid in enumerate(speech_token_ids)}
eos_id = tok.eos_token_id
allowed_ids = set(speech_token_ids + ([eos_id] if eos_id is not None else []))

# ---------- load prompt audio (16k), encode to codes ----------
wav16, sr = torchaudio.load("prompt_16k.wav")     # mono
assert sr == 16000
if wav16.size(0) > 1:
    wav16 = wav16.mean(dim=0, keepdim=True)

# Resample to 24k before codec
wav24 = torchaudio.functional.resample(wav16, orig_freq=16000, new_freq=24000).to(device)

bw = torch.tensor([0], device=device)
feats, codes = wt.encode_infer(wav24, bandwidth_id=bw)
# Normalize codes to a flat list[int] (keep the first stream if several are returned)
if codes.dim() == 3:
    codes = codes.squeeze(0)
    codes = codes[0] if codes.size(0) > 1 else codes.squeeze(0)
elif codes.dim() == 2:
    codes = codes.squeeze(0)
codes_list = codes.long().tolist()  # each in [0..4095]

# ---------- optional: decode round-trip for the stitched prefix ----------
recon24 = wt.decode(feats, bandwidth_id=bw)
if recon24.dim() == 3:
    recon24 = recon24.squeeze(0)

# ---------- build LM prefix string ----------
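# Codec codes are 0-based, while the LM's speech vocabulary [Sp1]..[Sp4096] is 1-based,
# hence the c+1 shift below; id2code (built above) reverses it when mapping generated
# token ids back to 0-based codec codes.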
prefix_text = speech_prefix + "".join(f"[Sp{c+1}]" for c in codes_list)
enc = tok(prefix_text, return_tensors="pt")
input_ids = enc["input_ids"].to(device)
attn_mask = enc.get("attention_mask", None)
if attn_mask is not None:
    attn_mask = attn_mask.to(device)

# ---------- generate continuation (about N seconds) ----------
seconds = 3.0
max_new_tokens = max(1, int(round(seconds * codes_per_second)))
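# e.g. at ~40 codes/s, seconds = 3.0 gives max_new_tokens = 120 (about 3 s of new audio)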

lp = LogitsProcessorList([SpeechOnlyLogitsProcessor(allowed_ids)])

gen = lm.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=max_new_tokens,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    repetition_penalty=1.1,
    eos_token_id=eos_id,
    pad_token_id=(tok.pad_token_id if tok.pad_token_id is not None else eos_id),
    logits_processor=lp,
)

# Strip prefix and EOS
gen_tail = gen[0][input_ids.size(1):].tolist()
if eos_id is not None and eos_id in gen_tail:
    gen_tail = gen_tail[:gen_tail.index(eos_id)]

# Map token ids -> code indices (0-based)
new_codes = [id2code[t] for t in gen_tail if t in id2code]

# Optionally keep the last ~1 s of prompt codes to avoid a hard seam at the join
keep_sec = 1.0
keep = max(0, int(round(keep_sec * codes_per_second)))
tail_codes = codes_list[-keep:] if keep > 0 else []
decode_codes = tail_codes + new_codes

# Decode to audio (24 kHz)
tok_tensor = torch.tensor(decode_codes, dtype=torch.long, device=device).view(1,1,-1)
cont24 = wt.decode(wt.codes_to_features(tok_tensor), bandwidth_id=bw)
if cont24.dim() == 3:
    cont24 = cont24.squeeze(0)

# Stitch with crossfade
stitched = equal_power_crossfade(recon24, cont24, fade_ms=60, sr=24000)

# Save files
sf.write("recon_24k.wav",     recon24.squeeze(0).detach().cpu().numpy(), 24000)
sf.write("continuation.wav",  cont24.squeeze(0).detach().cpu().numpy(),  24000)
sf.write("stitched_24k.wav",  stitched.squeeze(0).detach().cpu().numpy(),24000)
print("Wrote recon_24k.wav, continuation.wav, stitched_24k.wav")